diff --git a/backend-lib/Cargo.lock b/backend-lib/Cargo.lock index f704c9866..1932f5c08 100644 --- a/backend-lib/Cargo.lock +++ b/backend-lib/Cargo.lock @@ -7037,9 +7037,9 @@ dependencies = [ [[package]] name = "zcash_voting" -version = "0.5.3" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90abca13341b344d82315895bf50f26e93a66b5c4d2fbd8591728af604d4d4db" +checksum = "e1bb9e0ae40320acb03358d257a6fdc951866a21fb5fdd94ce5a09d0bccf25d2" dependencies = [ "anyhow", "blake2b_simd", diff --git a/backend-lib/Cargo.toml b/backend-lib/Cargo.toml index ea309ceae..c2db05eac 100644 --- a/backend-lib/Cargo.toml +++ b/backend-lib/Cargo.toml @@ -86,14 +86,14 @@ rust-analyzer = "0.0.1" # The Android JNI boundary for voting code is `src/main/rust/voting.rs`. # Its share-nullifier entry point forwards to # `zcash_voting::share_tracking::compute_share_nullifier`: -# https://github.com/valargroup/zcash_voting/blob/zcash_voting-v0.5.3/zcash_voting/src/share_tracking.rs +# https://github.com/valargroup/zcash_voting/blob/zcash_voting-v0.5.9/zcash_voting/src/share_tracking.rs # # The transitive `voting-circuits` dependency contains the share-reveal circuit # and the Poseidon-based `share_nullifier_hash` implementation: # https://github.com/valargroup/voting-circuits/blob/v0.4.1/voting-circuits/src/share_reveal/circuit.rs # Default features stay disabled so this foundation change does not pull in the # optional client PIR, tree-sync, or networking stacks. -zcash_voting = { version = "0.5.3", default-features = false } +zcash_voting = { version = "0.5.9", default-features = false } ## Uncomment this to test librustzcash changes locally #[patch.crates-io] diff --git a/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt b/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt index 67885dac7..5ae97617d 100644 --- a/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt +++ b/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt @@ -1,12 +1,18 @@ package cash.z.ecc.android.sdk.internal.jni +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundPhase import kotlinx.coroutines.test.runTest import org.junit.Test +import kotlin.io.path.createTempDirectory import kotlin.test.assertContentEquals +import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull @OptIn(ExperimentalStdlibApi::class) +@Suppress("MagicNumber") class VotingRustBackendTest { companion object { private const val FIELD_BYTES = 32 @@ -17,6 +23,15 @@ class VotingRustBackendTest { private val SHORT_FIELD = ByteArray(FIELD_BYTES - 1) private val EXPECTED_NULLIFIER = "8d6d97caa19a20e5e67e7cc24aaaa7beb72b4a513863f6adbe7b62ba1b1b0010".hexToByteArray() + + private const val WALLET_ID = "wallet-1" + private const val OTHER_WALLET_ID = "wallet-2" + private const val ROUND_ID = "round-1" + private const val SNAPSHOT_HEIGHT = 123_456L + private const val SESSION_JSON = "{\"round\":\"one\"}" + private val EA_PK = ByteArray(FIELD_BYTES) { 3 } + private val NC_ROOT = ByteArray(FIELD_BYTES) { 4 } + private val NULLIFIER_IMT_ROOT = ByteArray(FIELD_BYTES) { 5 } } @Test @@ -45,4 +60,106 @@ class VotingRustBackendTest { backend.computeShareNullifier(VOTE_COMMITMENT, OUT_OF_RANGE_SHARE_INDEX, BLIND) } } + + @Test + fun voting_db_round_state_round_trips() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + assertNull(db.getRoundState(ROUND_ID)) + + db.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = SESSION_JSON + ) + + val state = assertNotNull(db.getRoundState(ROUND_ID)) + assertEquals(ROUND_ID, state.roundId) + assertEquals(JniRoundPhase.INITIALIZED.value, state.phase) + assertEquals(JniRoundPhase.INITIALIZED, state.roundPhase) + assertEquals(SNAPSHOT_HEIGHT, state.snapshotHeight) + assertNull(state.hotkeyAddress) + assertNull(state.delegatedWeight) + assertFalse(state.proofGenerated) + + val rounds = db.listRounds() + assertEquals(1, rounds.size) + val round = rounds.single() + assertEquals(ROUND_ID, round.roundId) + assertEquals(JniRoundPhase.INITIALIZED.value, round.phase) + assertEquals(JniRoundPhase.INITIALIZED, round.roundPhase) + assertEquals(SNAPSHOT_HEIGHT, round.snapshotHeight) + + assertEquals(emptyList(), db.getVotes(ROUND_ID).asList()) + + db.clearRound(ROUND_ID) + assertNull(db.getRoundState(ROUND_ID)) + } finally { + db.close() + } + } + + @Test + fun voting_db_keeps_wallet_state_isolated() = + runTest { + val dbPath = newDbPath() + val firstWallet = VotingRustBackend.new().openVotingDb(dbPath, WALLET_ID) + val secondWallet = VotingRustBackend.new().openVotingDb(dbPath, OTHER_WALLET_ID) + try { + firstWallet.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + + assertNotNull(firstWallet.getRoundState(ROUND_ID)) + assertNull(secondWallet.getRoundState(ROUND_ID)) + } finally { + firstWallet.close() + secondWallet.close() + } + } + + @Test + fun voting_db_rejects_malformed_inputs_and_closed_handle() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + + assertFailsWith { + db.initRound( + roundId = ROUND_ID, + snapshotHeight = -1, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + } + assertFailsWith { + db.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = SHORT_FIELD, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + } + + db.close() + db.close() + assertFailsWith { + db.getRoundState(ROUND_ID) + } + } + + private fun newDbPath() = + createTempDirectory("voting-db-").resolve("voting.db").toFile().absolutePath } diff --git a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt index 5e5ad2dd9..95c1fc521 100644 --- a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt +++ b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt @@ -1,5 +1,16 @@ package cash.z.ecc.android.sdk.internal.jni +import androidx.annotation.Keep +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary +import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext + +@Keep +@Suppress("TooManyFunctions", "LongParameterList") class VotingRustBackend private constructor() { @Throws(RuntimeException::class) fun computeShareNullifier( @@ -8,6 +19,88 @@ class VotingRustBackend private constructor() { blind: ByteArray ): ByteArray = computeShareNullifierNative(voteCommitment, shareIndex, blind) + suspend fun openVotingDb(dbPath: String, walletId: String): VotingDb = + withContext(Dispatchers.IO) { + openVotingDbNative(dbPath, walletId).let { dbHandle -> + check(dbHandle != 0L) { + "openVotingDb failed for dbPath=$dbPath" + } + VotingDb(dbHandle) + } + } + + @Suppress("TooManyFunctions", "LongParameterList") + class VotingDb internal constructor( + private var dbHandle: Long? + ) { + private val accessMutex = Mutex() + + suspend fun close() { + accessMutex.withLock { + dbHandle?.let { handle -> + withContext(Dispatchers.IO) { + closeVotingDbNative(handle) + } + dbHandle = null + } + } + } + + @Throws(RuntimeException::class) + suspend fun initRound( + roundId: String, + snapshotHeight: Long, + eaPK: ByteArray, + ncRoot: ByteArray, + nullifierIMTRoot: ByteArray, + sessionJson: String? + ) = withHandle { handle -> + initRoundNative( + handle, + roundId, + snapshotHeight, + eaPK, + ncRoot, + nullifierIMTRoot, + sessionJson + ) + } + + @Throws(RuntimeException::class) + suspend fun getRoundState(roundId: String): JniRoundState? = + withHandle { handle -> getRoundStateNative(handle, roundId) } + + @Throws(RuntimeException::class) + suspend fun listRounds(): Array = + withHandle { handle -> listRoundsNative(handle) } + + @Throws(RuntimeException::class) + suspend fun getVotes(roundId: String): Array = + withHandle { handle -> getVotesNative(handle, roundId) } + + @Throws(RuntimeException::class) + suspend fun clearRound(roundId: String) = + withHandle { handle -> clearRoundNative(handle, roundId) } + + @Throws(RuntimeException::class) + suspend fun deleteSkippedBundles( + roundId: String, + keepCount: Int + ): Long = + withHandle { handle -> deleteSkippedBundlesNative(handle, roundId, keepCount) } + + private suspend fun withHandle(block: (Long) -> T): T = + accessMutex.withLock { + val handle = + checkNotNull(dbHandle) { + "Voting DB handle is closed" + } + withContext(Dispatchers.IO) { + block(handle) + } + } + } + companion object { suspend fun new(): VotingRustBackend { RustBackend.loadLibrary() @@ -22,5 +115,49 @@ class VotingRustBackend private constructor() { shareIndex: Int, blind: ByteArray ): ByteArray + + @JvmStatic + @Throws(RuntimeException::class) + private external fun openVotingDbNative(dbPath: String, walletId: String): Long + + @JvmStatic + @Throws(RuntimeException::class) + private external fun closeVotingDbNative(dbHandle: Long) + + @JvmStatic + @Throws(RuntimeException::class) + private external fun initRoundNative( + dbHandle: Long, + roundId: String, + snapshotHeight: Long, + eaPK: ByteArray, + ncRoot: ByteArray, + nullifierIMTRoot: ByteArray, + sessionJson: String? + ) + + @JvmStatic + @Throws(RuntimeException::class) + private external fun getRoundStateNative(dbHandle: Long, roundId: String): JniRoundState? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun listRoundsNative(dbHandle: Long): Array + + @JvmStatic + @Throws(RuntimeException::class) + private external fun getVotesNative(dbHandle: Long, roundId: String): Array + + @JvmStatic + @Throws(RuntimeException::class) + private external fun clearRoundNative(dbHandle: Long, roundId: String) + + @JvmStatic + @Throws(RuntimeException::class) + private external fun deleteSkippedBundlesNative( + dbHandle: Long, + roundId: String, + keepCount: Int + ): Long } } diff --git a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt new file mode 100644 index 000000000..cf017d986 --- /dev/null +++ b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt @@ -0,0 +1,57 @@ +package cash.z.ecc.android.sdk.internal.model.voting + +import androidx.annotation.Keep + +// Must match PHASE_* constants in backend-lib/src/main/rust/voting/helpers.rs. +internal const val JNI_ROUND_PHASE_INITIALIZED = 0 +internal const val JNI_ROUND_PHASE_HOTKEY_GENERATED = 1 +internal const val JNI_ROUND_PHASE_DELEGATION_CONSTRUCTED = 2 +internal const val JNI_ROUND_PHASE_DELEGATION_PROVED = 3 +internal const val JNI_ROUND_PHASE_VOTE_READY = 4 + +@Keep +data class JniRoundState( + val roundId: String, + val phase: Int, + val snapshotHeight: Long, + val hotkeyAddress: String?, + val delegatedWeight: Long?, + val proofGenerated: Boolean +) { + val roundPhase = JniRoundPhase.fromInt(phase) +} + +@Keep +enum class JniRoundPhase( + val value: Int +) { + INITIALIZED(JNI_ROUND_PHASE_INITIALIZED), + HOTKEY_GENERATED(JNI_ROUND_PHASE_HOTKEY_GENERATED), + DELEGATION_CONSTRUCTED(JNI_ROUND_PHASE_DELEGATION_CONSTRUCTED), + DELEGATION_PROVED(JNI_ROUND_PHASE_DELEGATION_PROVED), + VOTE_READY(JNI_ROUND_PHASE_VOTE_READY); + + companion object { + fun fromInt(value: Int) = + entries.firstOrNull { it.value == value } + ?: error("Unknown round phase: $value") + } +} + +@Keep +data class JniRoundSummary( + val roundId: String, + val phase: Int, + val snapshotHeight: Long, + val createdAt: Long +) { + val roundPhase = JniRoundPhase.fromInt(phase) +} + +@Keep +data class JniVoteRecord( + val proposalId: Int, + val bundleIndex: Int, + val choice: Int, + val submitted: Boolean +) diff --git a/backend-lib/src/main/rust/voting.rs b/backend-lib/src/main/rust/voting.rs index ad798e174..bbea42485 100644 --- a/backend-lib/src/main/rust/voting.rs +++ b/backend-lib/src/main/rust/voting.rs @@ -1,68 +1,28 @@ //! JNI bindings for the zcash_voting crate. -use std::ptr; - use anyhow::anyhow; use jni::{ JNIEnv, - objects::{JByteArray, JClass}, - sys::{jbyteArray, jint}, + objects::{JByteArray, JClass, JObject, JString, JValue}, + sys::{jboolean, jbyteArray, jint, jlong, jobject, jobjectArray}, +}; +use std::{ + collections::HashMap, + sync::{ + Arc, Mutex, OnceLock, + atomic::{AtomicI64, Ordering}, + }, }; use zcash_voting as voting; -use crate::utils::{self, catch_unwind, exception::unwrap_exc_or}; - -const VOTE_COMMITMENT_BYTES: usize = 32; -const BLIND_BYTES: usize = 32; -const SHARE_NULLIFIER_BYTES: usize = 32; - -/// Compute the share reveal nullifier from client-known inputs. -/// -/// Returns the 32-byte nullifier, or throws a RuntimeException and returns null -/// on malformed inputs. -#[unsafe(no_mangle)] -pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_computeShareNullifierNative< - 'local, ->( - mut env: JNIEnv<'local>, - _: JClass<'local>, - vote_commitment: JByteArray<'local>, - share_index: jint, - blind: JByteArray<'local>, -) -> jbyteArray { - let res = catch_unwind(&mut env, |env| { - let share_index = - u32::try_from(share_index).map_err(|_| anyhow!("shareIndex must be non-negative"))?; - let vote_commitment = - java_fixed_bytes::(env, &vote_commitment, "voteCommitment")?; - let blind = java_fixed_bytes::(env, &blind, "blind")?; +use voting::storage::{RoundPhase, RoundState, RoundSummary, VoteRecord, VotingDb}; - let nullifier = - voting::share_tracking::compute_share_nullifier(&vote_commitment, share_index, &blind) - .map_err(|e| anyhow!("compute_share_nullifier failed: {}", e))?; - let nullifier_len = nullifier.len(); - let nullifier: [u8; SHARE_NULLIFIER_BYTES] = nullifier.try_into().map_err(|_| { - anyhow!( - "shareNullifier must be exactly {} bytes, got {}", - SHARE_NULLIFIER_BYTES, - nullifier_len - ) - })?; - - Ok(utils::rust_bytes_to_java(env, &nullifier)?.into_raw()) - }); - unwrap_exc_or(&mut env, res, ptr::null_mut()) -} - -fn java_fixed_bytes( - env: &JNIEnv<'_>, - array: &JByteArray<'_>, - field: &str, -) -> anyhow::Result<[u8; N]> { - let bytes = utils::java_bytes_to_rust(env, array)?; - let len = bytes.len(); +use crate::utils::{ + catch_unwind, exception::unwrap_exc_or, java_nullable_string_to_rust, java_string_to_rust, + rust_vec_to_java, +}; - bytes - .try_into() - .map_err(|_| anyhow!("{field} must be exactly {N} bytes, got {len}")) -} +mod db; +mod helpers; +mod rounds; +mod share_tracking; diff --git a/backend-lib/src/main/rust/voting/db.rs b/backend-lib/src/main/rust/voting/db.rs new file mode 100644 index 000000000..670db4747 --- /dev/null +++ b/backend-lib/src/main/rust/voting/db.rs @@ -0,0 +1,78 @@ +use super::*; + +static NEXT_DB_HANDLE: AtomicI64 = AtomicI64::new(1); +static DB_REGISTRY: OnceLock>>> = OnceLock::new(); + +fn registry() -> &'static Mutex>> { + DB_REGISTRY.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn next_handle() -> anyhow::Result { + NEXT_DB_HANDLE + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| { + id.checked_add(1).filter(|next| *next > 0) + }) + .map_err(|_| anyhow!("voting DB handle space exhausted")) +} + +pub(super) fn db_from_handle(handle: jlong) -> anyhow::Result> { + if handle <= 0 { + return Err(anyhow!("Voting DB handle must be positive, got {handle}")); + } + + registry() + .lock() + .map_err(|_| anyhow!("voting DB registry mutex poisoned"))? + .get(&handle) + .cloned() + .ok_or_else(|| anyhow!("Voting DB handle is closed or unknown: {handle}")) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_openVotingDbNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_path: JString<'local>, + wallet_id: JString<'local>, +) -> jlong { + let res = catch_unwind(&mut env, |env| { + let path = java_string_to_rust(env, &db_path)?; + let wallet_id = java_string_to_rust(env, &wallet_id)?; + if wallet_id.is_empty() { + return Err(anyhow!("walletId must not be empty")); + } + + let db = VotingDb::open(&path).map_err(|e| anyhow!("VotingDb::open failed: {}", e))?; + db.set_wallet_id(&wallet_id); + let handle = next_handle()?; + registry() + .lock() + .map_err(|_| anyhow!("voting DB registry mutex poisoned"))? + .insert(handle, Arc::new(db)); + + Ok(handle) + }); + unwrap_exc_or(&mut env, res, 0) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_closeVotingDbNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, +) { + let res = catch_unwind(&mut env, |_| { + if db_handle > 0 { + registry() + .lock() + .map_err(|_| anyhow!("voting DB registry mutex poisoned"))? + .remove(&db_handle); + } + Ok(()) + }); + unwrap_exc_or(&mut env, res, ()) +} diff --git a/backend-lib/src/main/rust/voting/helpers.rs b/backend-lib/src/main/rust/voting/helpers.rs new file mode 100644 index 000000000..7960c98a9 --- /dev/null +++ b/backend-lib/src/main/rust/voting/helpers.rs @@ -0,0 +1,215 @@ +use super::*; + +// Must match JNI_ROUND_PHASE_* constants in JniVotingModels.kt. +const PHASE_INITIALIZED: u32 = 0; +const PHASE_HOTKEY_GENERATED: u32 = 1; +const PHASE_DELEGATION_CONSTRUCTED: u32 = 2; +const PHASE_DELEGATION_PROVED: u32 = 3; +const PHASE_VOTE_READY: u32 = 4; + +const JNI_ROUND_SUMMARY: &str = "cash/z/ecc/android/sdk/internal/model/voting/JniRoundSummary"; +const JNI_VOTE_RECORD: &str = "cash/z/ecc/android/sdk/internal/model/voting/JniVoteRecord"; + +pub(super) const PROTOCOL_FIELD_BYTES: usize = 32; +pub(super) const VOTE_COMMITMENT_BYTES: usize = PROTOCOL_FIELD_BYTES; +pub(super) const BLIND_BYTES: usize = PROTOCOL_FIELD_BYTES; +pub(super) const SHARE_NULLIFIER_BYTES: usize = 32; + +struct JniRoundSummaryPayload { + round_id: String, + phase: jint, + snapshot_height: jlong, + created_at: jlong, +} + +struct JniVoteRecordPayload { + proposal_id: jint, + bundle_index: jint, + choice: jint, + submitted: bool, +} + +pub(super) fn jint_to_u32(value: jint, field: &str) -> anyhow::Result { + u32::try_from(value).map_err(|_| anyhow!("{field} must be non-negative, got {value}")) +} + +pub(super) fn jlong_to_u64(value: jlong, field: &str) -> anyhow::Result { + u64::try_from(value).map_err(|_| anyhow!("{field} must be non-negative, got {value}")) +} + +fn u32_to_jint(value: u32, field: &str) -> anyhow::Result { + jint::try_from(value).map_err(|_| anyhow!("{field} exceeds signed Int range: {value}")) +} + +fn u64_to_jlong(value: u64, field: &str) -> anyhow::Result { + jlong::try_from(value).map_err(|_| anyhow!("{field} exceeds signed Long range: {value}")) +} + +pub(super) fn require_len(bytes: Vec, field: &str, expected: usize) -> anyhow::Result> { + if bytes.len() == expected { + Ok(bytes) + } else { + Err(anyhow!( + "{field} must be exactly {expected} bytes, got {}", + bytes.len() + )) + } +} + +pub(super) fn java_bytes( + env: &mut JNIEnv<'_>, + array: &JByteArray<'_>, + field: &str, +) -> anyhow::Result> { + env.convert_byte_array(array) + .map_err(|e| anyhow!("{field}: failed to read byte array: {e}")) +} + +pub(super) fn java_bytes_exact( + env: &mut JNIEnv<'_>, + array: &JByteArray<'_>, + field: &str, + expected: usize, +) -> anyhow::Result> { + require_len(java_bytes(env, array, field)?, field, expected) +} + +pub(super) fn java_fixed_bytes( + env: &mut JNIEnv<'_>, + array: &JByteArray<'_>, + field: &str, +) -> anyhow::Result<[u8; N]> { + fixed_bytes(java_bytes(env, array, field)?, field) +} + +pub(super) fn fixed_bytes(bytes: Vec, field: &str) -> anyhow::Result<[u8; N]> { + let len = bytes.len(); + + bytes + .try_into() + .map_err(|_| anyhow!("{field} must be exactly {N} bytes, got {len}")) +} + +pub(super) fn round_phase_to_u32(phase: RoundPhase) -> u32 { + match phase { + RoundPhase::Initialized => PHASE_INITIALIZED, + RoundPhase::HotkeyGenerated => PHASE_HOTKEY_GENERATED, + RoundPhase::DelegationConstructed => PHASE_DELEGATION_CONSTRUCTED, + RoundPhase::DelegationProved => PHASE_DELEGATION_PROVED, + RoundPhase::VoteReady => PHASE_VOTE_READY, + } +} + +pub(super) fn make_jni_round_state<'local>( + env: &mut JNIEnv<'local>, + state: RoundState, +) -> anyhow::Result { + let phase = round_phase_to_u32(state.phase) as i32; + let class = env.find_class("cash/z/ecc/android/sdk/internal/model/voting/JniRoundState")?; + let round_id_obj: JObject<'local> = env.new_string(&state.round_id)?.into(); + let hotkey_obj: JObject<'local> = match &state.hotkey_address { + Some(a) => env.new_string(a)?.into(), + None => JObject::null(), + }; + let long_class = env.find_class("java/lang/Long")?; + let weight_obj: JObject<'local> = match state.delegated_weight { + Some(w) => env.new_object(&long_class, "(J)V", &[JValue::Long(w as i64)])?, + None => JObject::null(), + }; + let obj = env.new_object( + &class, + // Matches JniRoundState(roundId, phase, snapshotHeight, hotkeyAddress, + // delegatedWeight, proofGenerated). + "(Ljava/lang/String;IJLjava/lang/String;Ljava/lang/Long;Z)V", + &[ + JValue::Object(&round_id_obj), + JValue::Int(phase), + JValue::Long(state.snapshot_height as i64), + JValue::Object(&hotkey_obj), + JValue::Object(&weight_obj), + JValue::Bool(state.proof_generated as jboolean), + ], + )?; + Ok(obj.into_raw()) +} + +pub(super) fn make_jni_round_summaries( + env: &mut JNIEnv<'_>, + rounds: Vec, +) -> anyhow::Result { + let payloads = rounds + .into_iter() + .map(JniRoundSummaryPayload::try_from) + .collect::>>()?; + + Ok( + rust_vec_to_java(env, payloads, JNI_ROUND_SUMMARY, |env, round| { + let round_id_obj: JObject<'_> = env.new_string(round.round_id)?.into(); + env.new_object( + JNI_ROUND_SUMMARY, + // Matches JniRoundSummary(roundId, phase, snapshotHeight, createdAt). + "(Ljava/lang/String;IJJ)V", + &[ + JValue::Object(&round_id_obj), + JValue::Int(round.phase), + JValue::Long(round.snapshot_height), + JValue::Long(round.created_at), + ], + ) + })? + .into_raw(), + ) +} + +pub(super) fn make_jni_vote_records( + env: &mut JNIEnv<'_>, + votes: Vec, +) -> anyhow::Result { + let payloads = votes + .into_iter() + .map(JniVoteRecordPayload::try_from) + .collect::>>()?; + + Ok( + rust_vec_to_java(env, payloads, JNI_VOTE_RECORD, |env, vote| { + env.new_object( + JNI_VOTE_RECORD, + // Matches JniVoteRecord(proposalId, bundleIndex, choice, submitted). + "(IIIZ)V", + &[ + JValue::Int(vote.proposal_id), + JValue::Int(vote.bundle_index), + JValue::Int(vote.choice), + JValue::Bool(vote.submitted as jboolean), + ], + ) + })? + .into_raw(), + ) +} + +impl TryFrom for JniRoundSummaryPayload { + type Error = anyhow::Error; + + fn try_from(round: RoundSummary) -> anyhow::Result { + Ok(JniRoundSummaryPayload { + round_id: round.round_id, + phase: u32_to_jint(round_phase_to_u32(round.phase), "phase")?, + snapshot_height: u64_to_jlong(round.snapshot_height, "snapshot_height")?, + created_at: u64_to_jlong(round.created_at, "created_at")?, + }) + } +} + +impl TryFrom for JniVoteRecordPayload { + type Error = anyhow::Error; + + fn try_from(record: VoteRecord) -> anyhow::Result { + Ok(JniVoteRecordPayload { + proposal_id: u32_to_jint(record.proposal_id, "proposal_id")?, + bundle_index: u32_to_jint(record.bundle_index, "bundle_index")?, + choice: u32_to_jint(record.choice, "choice")?, + submitted: record.submitted, + }) + } +} diff --git a/backend-lib/src/main/rust/voting/rounds.rs b/backend-lib/src/main/rust/voting/rounds.rs new file mode 100644 index 000000000..2dce8f909 --- /dev/null +++ b/backend-lib/src/main/rust/voting/rounds.rs @@ -0,0 +1,144 @@ +use super::db::*; +use super::helpers::*; +use super::*; + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_initRoundNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, + snapshot_height: jlong, + ea_pk: JByteArray<'local>, + nc_root: JByteArray<'local>, + nullifier_imt_root: JByteArray<'local>, + session_json: JString<'local>, +) { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let params = voting::types::VotingRoundParams { + vote_round_id: java_string_to_rust(env, &round_id)?, + snapshot_height: jlong_to_u64(snapshot_height, "snapshot_height")?, + ea_pk: java_bytes_exact(env, &ea_pk, "ea_pk", PROTOCOL_FIELD_BYTES)?, + nc_root: java_bytes_exact(env, &nc_root, "nc_root", PROTOCOL_FIELD_BYTES)?, + nullifier_imt_root: java_bytes_exact( + env, + &nullifier_imt_root, + "nullifier_imt_root", + PROTOCOL_FIELD_BYTES, + )?, + }; + let session = java_nullable_string_to_rust(env, &session_json)?; + db.init_round(¶ms, session.as_deref()) + .map_err(|e| anyhow!("init_round: {}", e))?; + Ok(()) + }); + unwrap_exc_or(&mut env, res, ()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_getRoundStateNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, +) -> jobject { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let round_id = java_string_to_rust(env, &round_id)?; + if !db + .has_round(&round_id) + .map_err(|e| anyhow!("has_round: {}", e))? + { + Ok(JObject::null().into_raw()) + } else { + let state = db + .get_round_state(&round_id) + .map_err(|e| anyhow!("get_round_state: {}", e))?; + make_jni_round_state(env, state) + } + }); + unwrap_exc_or(&mut env, res, JObject::null().into_raw()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_listRoundsNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, +) -> jobjectArray { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let rounds = db + .list_rounds() + .map_err(|e| anyhow!("list_rounds: {}", e))?; + make_jni_round_summaries(env, rounds) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_getVotesNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, +) -> jobjectArray { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let votes = db + .get_votes(&java_string_to_rust(env, &round_id)?) + .map_err(|e| anyhow!("get_votes: {}", e))?; + make_jni_vote_records(env, votes) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_clearRoundNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, +) { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + db.clear_round(&java_string_to_rust(env, &round_id)?) + .map_err(|e| anyhow!("clear_round: {}", e))?; + Ok(()) + }); + unwrap_exc_or(&mut env, res, ()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_deleteSkippedBundlesNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, + keep_count: jint, +) -> jlong { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let deleted_rows = db + .delete_skipped_bundles( + &java_string_to_rust(env, &round_id)?, + jint_to_u32(keep_count, "keep_count")?, + ) + .map_err(|e| anyhow!("delete_skipped_bundles: {}", e))?; + Ok(deleted_rows as jlong) + }); + unwrap_exc_or(&mut env, res, -1) +} diff --git a/backend-lib/src/main/rust/voting/share_tracking.rs b/backend-lib/src/main/rust/voting/share_tracking.rs new file mode 100644 index 000000000..5d379e0fa --- /dev/null +++ b/backend-lib/src/main/rust/voting/share_tracking.rs @@ -0,0 +1,29 @@ +use super::helpers::*; +use super::*; + +/// Compute the share reveal nullifier from client-known inputs. +/// +/// Returns the 32-byte nullifier, or throws a RuntimeException and returns null +/// on malformed inputs. +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_computeShareNullifierNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + vote_commitment: JByteArray<'local>, + share_index: jint, + blind: JByteArray<'local>, +) -> jbyteArray { + let res = catch_unwind(&mut env, |env| { + let nullifier = voting::share_tracking::compute_share_nullifier( + &java_fixed_bytes::(env, &vote_commitment, "voteCommitment")?, + jint_to_u32(share_index, "share_index")?, + &java_fixed_bytes::(env, &blind, "blind")?, + ) + .map_err(|e| anyhow!("compute_share_nullifier: {}", e))?; + let nullifier = fixed_bytes::(nullifier, "shareNullifier")?; + Ok(env.byte_array_from_slice(&nullifier)?.into_raw()) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt new file mode 100644 index 000000000..e983b4aad --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt @@ -0,0 +1,37 @@ +package cash.z.ecc.android.sdk.internal + +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary +import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord + +@Suppress("TooManyFunctions", "LongParameterList") +interface TypesafeVotingBackend { + suspend fun openVotingDb(dbPath: String, walletId: String): TypesafeVotingDb +} + +@Suppress("TooManyFunctions", "LongParameterList") +interface TypesafeVotingDb { + suspend fun close() + + suspend fun initRound( + roundId: String, + snapshotHeight: Long, + eaPK: ByteArray, + ncRoot: ByteArray, + nullifierIMTRoot: ByteArray, + sessionJson: String? + ) + + suspend fun getRoundState(roundId: String): JniRoundState? + + suspend fun listRounds(): List + + suspend fun getVotes(roundId: String): List + + suspend fun clearRound(roundId: String) + + suspend fun deleteSkippedBundles( + roundId: String, + keepCount: Int + ): Long +} diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt new file mode 100644 index 000000000..3985ccb38 --- /dev/null +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt @@ -0,0 +1,59 @@ +package cash.z.ecc.android.sdk.internal + +import cash.z.ecc.android.sdk.internal.jni.VotingRustBackend +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState +import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary +import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord + +@Suppress("TooManyFunctions", "LongParameterList") +class TypesafeVotingBackendImpl : TypesafeVotingBackend { + private val rustBackendLazy = + SuspendingLazy { + VotingRustBackend.new() + } + + override suspend fun openVotingDb(dbPath: String, walletId: String): TypesafeVotingDb = + TypesafeVotingDbImpl(rustBackend().openVotingDb(dbPath, walletId)) + + private suspend fun rustBackend() = rustBackendLazy.getInstance(Unit) +} + +@Suppress("TooManyFunctions", "LongParameterList") +private class TypesafeVotingDbImpl( + private val votingDb: VotingRustBackend.VotingDb +) : TypesafeVotingDb { + override suspend fun close() = votingDb.close() + + override suspend fun initRound( + roundId: String, + snapshotHeight: Long, + eaPK: ByteArray, + ncRoot: ByteArray, + nullifierIMTRoot: ByteArray, + sessionJson: String? + ) = votingDb.initRound( + roundId, + snapshotHeight, + eaPK, + ncRoot, + nullifierIMTRoot, + sessionJson + ) + + override suspend fun getRoundState(roundId: String): JniRoundState? = + votingDb.getRoundState(roundId) + + override suspend fun listRounds(): List = + votingDb.listRounds().asList() + + override suspend fun getVotes(roundId: String): List = + votingDb.getVotes(roundId).asList() + + override suspend fun clearRound(roundId: String) = + votingDb.clearRound(roundId) + + override suspend fun deleteSkippedBundles( + roundId: String, + keepCount: Int + ): Long = votingDb.deleteSkippedBundles(roundId, keepCount) +}