Skip to content

Commit

Permalink
[Android] Add a sample instrumentedTest (pytorch#800)
Browse files Browse the repository at this point in the history
* [Android] Add a sample instrumentedTest

We can't add a pure Java unit test because we have Android JNI deps. We
use instrumented test instead. This is runnable if we have an emulator
on host.

Test: pushd android/Torchchat/; ./gradlew connectedAndroidTest; popd

* Add docs for LlamaModuleTest.java
  • Loading branch information
kirklandsign authored and malfet committed Jul 17, 2024
1 parent 358931d commit a282178
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 26 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.pytorch.torchchat;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

import static org.junit.Assert.*;

/**
* Instrumented test, which will execute on an Android device.
*
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
*/
@RunWith(AndroidJUnit4.class)
public class LlamaModuleTest {
@Test
public void LlamaModule() {
LlamaModule module = new LlamaModule("/data/local/tmp/llm/model.pte", "/data/local/tmp/llm/tokenizer.bin", 0.8f);
assertEquals(module.load(), 0);
MyLlamaCallback callback = new MyLlamaCallback();
// Note: module.generate() is synchronous. Callback happens within the same thread as
// generate() so when generate() returns, all callbacks are invoked.
assertEquals(module.generate("Hey", callback), 0);
assertNotEquals("", callback.result);
}
}

/**
* LlamaCallback for testing.
*
* Note: onResult() and onStats() are invoked within the same thread as LlamaModule.generate()
*
* @see <a href="https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java">LlamaCallback interface guide</a>
*/
class MyLlamaCallback implements LlamaCallback {
String result = "";
@Override
public void onResult(String s) {
result += s;
}

@Override
public void onStats(float v) {

}
}

0 comments on commit a282178

Please sign in to comment.