From a28217819482668d61e3877e5e3963da8eba8a9e Mon Sep 17 00:00:00 2001
From: Hansong <107070759+kirklandsign@users.noreply.github.com>
Date: Wed, 15 May 2024 12:04:18 -0700
Subject: [PATCH] [Android] Add a sample instrumentedTest (#800)

* [Android] Add a sample instrumentedTest

We can't add a pure Java unit test because we have Android JNI deps. We
use instrumented test instead. This is runnable if we have an emulator
on host.

Test: pushd android/Torchchat/; ./gradlew connectedAndroidTest; popd

* Add docs for LlamaModuleTest.java
---
 .../torchchat/ExampleInstrumentedTest.java    | 26 ----------
 .../pytorch/torchchat/LlamaModuleTest.java    | 49 +++++++++++++++++++
 2 files changed, 49 insertions(+), 26 deletions(-)
 delete mode 100644 android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java
 create mode 100644 android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java

diff --git a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java
deleted file mode 100644
index 06cc6606c2..0000000000
--- a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package org.pytorch.torchchat;
-
-import android.content.Context;
-
-import androidx.test.platform.app.InstrumentationRegistry;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import static org.junit.Assert.*;
-
-/**
- * Instrumented test, which will execute on an Android device.
- *
- * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
- */
-@RunWith(AndroidJUnit4.class)
-public class ExampleInstrumentedTest {
-    @Test
-    public void useAppContext() {
-        // Context of the app under test.
-        Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
-        assertEquals("org.pytorch.torchchat", appContext.getPackageName());
-    }
-}
diff --git a/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java
new file mode 100644
index 0000000000..df8db8162d
--- /dev/null
+++ b/android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java
@@ -0,0 +1,49 @@
+package org.pytorch.torchchat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.pytorch.executorch.LlamaCallback;
+import org.pytorch.executorch.LlamaModule;
+
+import static org.junit.Assert.*;
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
+ */
+@RunWith(AndroidJUnit4.class)
+public class LlamaModuleTest {
+    @Test
+    public void LlamaModule() {
+        LlamaModule module = new LlamaModule("/data/local/tmp/llm/model.pte", "/data/local/tmp/llm/tokenizer.bin", 0.8f);
+        assertEquals(module.load(), 0);
+        MyLlamaCallback callback = new MyLlamaCallback();
+        // Note: module.generate() is synchronous. Callback happens within the same thread as
+        // generate() so when generate() returns, all callbacks are invoked.
+        assertEquals(module.generate("Hey", callback), 0);
+        assertNotEquals("", callback.result);
+    }
+}
+
+/**
+ * LlamaCallback for testing.
+ *
+ * Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate()
+ *
+ * @see <a href="https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java">LlamaCallback interface guide</a>
+ */
+class MyLlamaCallback implements LlamaCallback {
+    String result = "";
+    @Override
+    public void onResult(String s) {
+        result += s;
+    }
+
+    @Override
+    public void onStats(float v) {
+
+    }
+}