From a28217819482668d61e3877e5e3963da8eba8a9e Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Wed, 15 May 2024 12:04:18 -0700 Subject: [PATCH] [Android] Add a sample instrumentedTest (#800) * [Android] Add a sample instrumentedTest We can't add a pure Java unit test because we have Android JNI deps. We use instrumented test instead. This is runnable if we have an emulator on host. Test: pushd android/Torchchat/; ./gradlew connectedAndroidTest; popd * Add docs for LlamaModuleTest.java --- .../torchchat/ExampleInstrumentedTest.java | 26 ---------- .../pytorch/torchchat/LlamaModuleTest.java | 49 +++++++++++++++++++ 2 files changed, 49 insertions(+), 26 deletions(-) delete mode 100644 android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java create mode 100644 android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java diff --git a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java deleted file mode 100644 index 06cc6606c2..0000000000 --- a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.pytorch.torchchat; - -import android.content.Context; - -import androidx.test.platform.app.InstrumentationRegistry; -import androidx.test.ext.junit.runners.AndroidJUnit4; - -import org.junit.Test; -import org.junit.runner.RunWith; - -import static org.junit.Assert.*; - -/** - * Instrumented test, which will execute on an Android device. - * - * @see <a href="http://d.android.com/tools/testing">Testing documentation</a> - */ -@RunWith(AndroidJUnit4.class) -public class ExampleInstrumentedTest { - @Test - public void useAppContext() { - // Context of the app under test. - Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); - assertEquals("org.pytorch.torchchat", appContext.getPackageName()); - } -} diff --git a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java new file mode 100644 index 0000000000..df8db8162d --- /dev/null +++ b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java @@ -0,0 +1,49 @@ +package org.pytorch.torchchat; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see <a href="http://d.android.com/tools/testing">Testing documentation</a> + */ +@RunWith(AndroidJUnit4.class) +public class LlamaModuleTest { + @Test + public void LlamaModule() { + LlamaModule module = new LlamaModule("/data/local/tmp/llm/model.pte", "/data/local/tmp/llm/tokenizer.bin", 0.8f); + assertEquals(module.load(), 0); + MyLlamaCallback callback = new MyLlamaCallback(); + // Note: module.generate() is synchronous. Callback happens within the same thread as + // generate() so when generate() returns, all callbacks are invoked. + assertEquals(module.generate("Hey", callback), 0); + assertNotEquals("", callback.result); + } +} + +/** + * LlamaCallback for testing. + * + * Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate() + * + * @see <a href="https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java">LlamaCallback interface guide</a> + */ +class MyLlamaCallback implements LlamaCallback { + String result = ""; + @Override + public void onResult(String s) { + result += s; + } + + @Override + public void onStats(float v) { + + } +}