forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Android] Add a sample instrumentedTest (pytorch#800)
* [Android] Add a sample instrumentedTest We can't add a pure Java unit test because we have Android JNI deps. We use instrumented test instead. This is runnable if we have an emulator on host. Test: pushd android/Torchchat/; ./gradlew connectedAndroidTest; popd * Add docs for LlamaModuleTest.java
- Loading branch information
1 parent
358931d
commit a282178
Showing
2 changed files
with
49 additions
and
26 deletions.
There are no files selected for viewing
26 changes: 0 additions & 26 deletions
26
...oid/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/ExampleInstrumentedTest.java
This file was deleted.
Oops, something went wrong.
49 changes: 49 additions & 0 deletions
49
android/Torchchat/app/src/androidTest/java/org/pytorch/torchchat/LlamaModuleTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package org.pytorch.torchchat; | ||
|
||
import androidx.test.ext.junit.runners.AndroidJUnit4; | ||
|
||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.pytorch.executorch.LlamaCallback; | ||
import org.pytorch.executorch.LlamaModule; | ||
|
||
import static org.junit.Assert.*; | ||
|
||
/** | ||
* Instrumented test, which will execute on an Android device. | ||
* | ||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a> | ||
*/ | ||
@RunWith(AndroidJUnit4.class) | ||
public class LlamaModuleTest { | ||
@Test | ||
public void LlamaModule() { | ||
LlamaModule module = new LlamaModule("/data/local/tmp/llm/model.pte", "/data/local/tmp/llm/tokenizer.bin", 0.8f); | ||
assertEquals(module.load(), 0); | ||
MyLlamaCallback callback = new MyLlamaCallback(); | ||
// Note: module.generate() is synchronous. Callback happens within the same thread as | ||
// generate() so when generate() returns, all callbacks are invoked. | ||
assertEquals(module.generate("Hey", callback), 0); | ||
assertNotEquals("", callback.result); | ||
} | ||
} | ||
|
||
/** | ||
* LlamaCallback for testing. | ||
* | ||
* Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate() | ||
* | ||
* @see <a href="https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java">LlamaCallback interface guide</a> | ||
*/ | ||
class MyLlamaCallback implements LlamaCallback { | ||
String result = ""; | ||
@Override | ||
public void onResult(String s) { | ||
result += s; | ||
} | ||
|
||
@Override | ||
public void onStats(float v) { | ||
|
||
} | ||
} |