diff --git a/CHANGELOG.md b/CHANGELOG.md index f1c5213d2..02c75f605 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,13 @@ and this library adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Synchronizer.fullyScannedHeight` and `Synchronizer.getTreeState` accessors for snapshot-height consumers. +### Changed +- `String.fromHex` now rejects odd-length and non-hex input instead of silently coercing malformed + strings. + ### Internal -- Added the Rust `zcash_voting` dependency foundation for future shielded voting backend work. +- Added internal `VotingRustBackend` / `TypesafeVotingBackend` plumbing for future shielded voting backend work. +- Pinned `orchard` to `=0.13.1` with `unstable-voting-circuits` to match `zcash_voting` / `voting-circuits` requirements. ## [2.5.0] - 2026-05-01 diff --git a/backend-lib/Cargo.lock b/backend-lib/Cargo.lock index 7e4cbefd7..749aeaeb5 100644 --- a/backend-lib/Cargo.lock +++ b/backend-lib/Cargo.lock @@ -6698,6 +6698,7 @@ dependencies = [ "dlopen2", "eip681", "fs-mistrust", + "hex", "http", "http-body-util", "jni", @@ -6715,6 +6716,8 @@ dependencies = [ "rust_decimal", "sapling-crypto", "secrecy", + "serde", + "serde_json", "tonic", "tor-rtcompat", "tracing", diff --git a/backend-lib/Cargo.toml b/backend-lib/Cargo.toml index 4d35ce40b..f29bcc3c2 100644 --- a/backend-lib/Cargo.toml +++ b/backend-lib/Cargo.toml @@ -54,6 +54,9 @@ anyhow = "1" jni = { version = "0.21", default-features = false } uuid = "1" bitflags = "2" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +hex = "0.4" # lightwalletd tonic = "0.14" @@ -81,8 +84,8 @@ rust-analyzer = "0.0.1" # # `zcash_voting` provides client-side primitives for shielded voting. It depends # on Orchard's unstable voting-circuits APIs, so the direct Orchard dependency is -# pinned to the same version and enables the same feature. Remove the exact pin -# once these circuit APIs are available through a stable Orchard feature. +# pinned to the same version and enables the same feature. Revisit the exact pin +# if the upstream voting crates relax their Orchard requirement. # The Android JNI boundary for voting code is `src/main/rust/voting.rs`. # Its share-nullifier entry point forwards to # `zcash_voting::share_tracking::compute_share_nullifier`: diff --git a/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt b/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt index 5ae97617d..eab6ef343 100644 --- a/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt +++ b/backend-lib/src/androidTest/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackendTest.kt @@ -2,6 +2,8 @@ package cash.z.ecc.android.sdk.internal.jni import cash.z.ecc.android.sdk.internal.model.voting.JniRoundPhase import kotlinx.coroutines.test.runTest +import org.json.JSONArray +import org.json.JSONObject import org.junit.Test import kotlin.io.path.createTempDirectory import kotlin.test.assertContentEquals @@ -10,6 +12,7 @@ import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertNull +import kotlin.test.assertTrue @OptIn(ExperimentalStdlibApi::class) @Suppress("MagicNumber") @@ -18,6 +21,7 @@ class VotingRustBackendTest { private const val FIELD_BYTES = 32 private const val SHARE_INDEX = 5 private const val OUT_OF_RANGE_SHARE_INDEX = 16 + private const val DIVERSIFIER_BYTES = 11 private val VOTE_COMMITMENT = ByteArray(FIELD_BYTES) { 1 } private val BLIND = ByteArray(FIELD_BYTES) { 2 } private val SHORT_FIELD = ByteArray(FIELD_BYTES - 1) @@ -29,9 +33,25 @@ class VotingRustBackendTest { private const val ROUND_ID = "round-1" private const val SNAPSHOT_HEIGHT = 123_456L private const val SESSION_JSON = "{\"round\":\"one\"}" + private const val TESTNET_NETWORK_ID = JNI_VOTING_NETWORK_ID_TESTNET + private const val ACCOUNT_INDEX = 0 + private const val ADDRESS_INDEX = 1 + private const val MAINNET_NETWORK_ID = JNI_VOTING_NETWORK_ID_MAINNET + private const val SECOND_ROUND_ID = "round-2" + private const val PCZT_ROUND_ID = + "0101010101010101010101010101010101010101010101010101010101010101" + private const val ROUND_NAME = "Test Round" + private const val NOTE_VALUE = 13_000_000L + private const val PCZT_NOTE_VALUE = 15_000_000L + private const val LARGE_BUNDLE_WEIGHT = 62_500_000L + private const val SMALL_BUNDLE_WEIGHT = 12_500_000L + private const val TWO_BUNDLE_ELIGIBLE_WEIGHT = 75_000_000L private val EA_PK = ByteArray(FIELD_BYTES) { 3 } private val NC_ROOT = ByteArray(FIELD_BYTES) { 4 } private val NULLIFIER_IMT_ROOT = ByteArray(FIELD_BYTES) { 5 } + private val HOTKEY_SEED = ByteArray(64) { 0x42 } + private val OTHER_HOTKEY_SEED = ByteArray(64) { 0x43 } + private val SEED_FINGERPRINT = ByteArray(FIELD_BYTES) { 6 } } @Test @@ -61,6 +81,12 @@ class VotingRustBackendTest { } } + @Test + fun warm_proving_caches_smoke() = + runTest { + VotingRustBackend.new().warmProvingCaches() + } + @Test fun voting_db_round_state_round_trips() = runTest { @@ -127,6 +153,41 @@ class VotingRustBackendTest { } } + @Test + fun list_rounds_returns_all_rounds_for_current_wallet_only() = + runTest { + val dbPath = newDbPath() + val firstWallet = VotingRustBackend.new().openVotingDb(dbPath, WALLET_ID) + val secondWallet = VotingRustBackend.new().openVotingDb(dbPath, OTHER_WALLET_ID) + try { + firstWallet.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + firstWallet.initRound( + roundId = SECOND_ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + + val firstWalletRounds = firstWallet.listRounds().map { it.roundId }.toSet() + val secondWalletRounds = secondWallet.listRounds() + + assertEquals(setOf(ROUND_ID, SECOND_ROUND_ID), firstWalletRounds) + assertEquals(0, secondWalletRounds.size) + } finally { + firstWallet.close() + secondWallet.close() + } + } + @Test fun voting_db_rejects_malformed_inputs_and_closed_handle() = runTest { @@ -160,6 +221,318 @@ class VotingRustBackendTest { } } + @Test + fun compute_bundle_setup_returns_exact_weights() = + runTest { + val setup = VotingRustBackend.new().computeBundleSetup(notesJson(noteCount = 6)) + + assertEquals(2, setup.bundleCount) + assertEquals(TWO_BUNDLE_ELIGIBLE_WEIGHT, setup.eligibleWeight) + assertEquals(listOf(LARGE_BUNDLE_WEIGHT, SMALL_BUNDLE_WEIGHT), setup.bundleWeights) + assertEquals(setup.eligibleWeight, setup.bundleWeights.sum()) + } + + @Test + fun compute_bundle_setup_rejects_unknown_note_scope() = + runTest { + val notesJson = + JSONArray() + .put(noteJson(value = NOTE_VALUE, position = 0, byteValue = 1, scope = 2)) + .toString() + + assertFailsWith { + VotingRustBackend.new().computeBundleSetup(notesJson) + } + } + + @Test + fun compute_bundle_setup_rejects_malformed_diversifier() = + runTest { + val notesJson = + JSONArray() + .put( + noteJson(value = NOTE_VALUE, position = 0, byteValue = 1) + .put("diversifier", repeatedHex(0, DIVERSIFIER_BYTES - 1)) + ).toString() + + assertFailsWith { + VotingRustBackend.new().computeBundleSetup(notesJson) + } + } + + @Test + fun setup_bundles_round_trips_bundle_count() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + db.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + + val setup = db.setupBundles(ROUND_ID, notesJson(noteCount = 6)) + + assertEquals(2, setup.bundleCount) + assertEquals(TWO_BUNDLE_ELIGIBLE_WEIGHT, setup.eligibleWeight) + assertEquals(listOf(LARGE_BUNDLE_WEIGHT, SMALL_BUNDLE_WEIGHT), setup.bundleWeights) + assertEquals(setup.eligibleWeight, setup.bundleWeights.sum()) + assertEquals(2, db.getBundleCount(ROUND_ID)) + + val deletedRows = db.deleteSkippedBundles(ROUND_ID, keepCount = 1) + assertEquals(1L, deletedRows) + assertEquals(1, db.getBundleCount(ROUND_ID)) + } finally { + db.close() + } + } + + @Test + fun generate_hotkey_is_deterministic_and_rejects_short_seed() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + db.initRound( + roundId = ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + + val first = db.generateHotkey(ROUND_ID, HOTKEY_SEED) + val second = db.generateHotkey(ROUND_ID, HOTKEY_SEED) + val other = db.generateHotkey(ROUND_ID, OTHER_HOTKEY_SEED) + + assertContentEquals(first.publicKey.value, second.publicKey.value) + assertEquals(first.address, second.address) + assertFalse(first.publicKey.value.contentEquals(other.publicKey.value)) + assertEquals(FIELD_BYTES, first.publicKey.value.size) + assertTrue(first.address.startsWith("sv1")) + assertEquals( + JniRoundPhase.HOTKEY_GENERATED, + assertNotNull(db.getRoundState(ROUND_ID)).roundPhase + ) + + assertFailsWith { + db.generateHotkey(ROUND_ID, SHORT_FIELD) + } + } finally { + db.close() + } + } + + @Test + fun build_governance_pczt_rejects_mismatched_bundle_inputs_and_seed() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + val notesJson = notesJson(noteCount = 6, value = PCZT_NOTE_VALUE) + val mismatchedNotesJson = notesJson(noteCount = 1, value = PCZT_NOTE_VALUE) + val mismatchedSameIndexNotesJson = + notesJson(noteCount = 6, value = PCZT_NOTE_VALUE, positionOffset = 10) + val mismatchedSamePositionNotesJson = + notesJson(noteCount = 6, value = PCZT_NOTE_VALUE, ufvkString = "different") + val ufvk = deriveTestUfvk() + val mismatchedUfvk = deriveTestUfvk(seed = OTHER_HOTKEY_SEED) + db.initPcztRoundWithBundles(notesJson) + + assertFailsWith { + db.buildTestGovernancePcztJson(ufvk, mismatchedNotesJson) + } + assertFailsWith { + db.buildTestGovernancePcztJson(ufvk, mismatchedSameIndexNotesJson) + } + assertFailsWith { + db.buildTestGovernancePcztJson(ufvk, mismatchedSamePositionNotesJson) + } + assertFailsWith { + db.buildTestGovernancePcztJson(mismatchedUfvk, notesJson) + } + } finally { + db.close() + } + } + + @Test + fun build_governance_pczt_requires_hotkey_generated_phase() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + val notesJson = notesJson(noteCount = 6, value = PCZT_NOTE_VALUE) + val ufvk = deriveTestUfvk() + db.initRound( + roundId = PCZT_ROUND_ID, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + db.setupBundles(PCZT_ROUND_ID, notesJson) + + assertFailsWith { + db.buildTestGovernancePcztJson(ufvk, notesJson) + } + assertEquals( + JniRoundPhase.INITIALIZED, + assertNotNull(db.getRoundState(PCZT_ROUND_ID)).roundPhase + ) + } finally { + db.close() + } + } + + @Test + fun build_governance_pczt_returns_parseable_pczt_and_extractable_sighash() = + runTest { + val backend = VotingRustBackend.new() + val db = backend.openVotingDb(newDbPath(), WALLET_ID) + try { + val notesJson = notesJson(noteCount = 6, value = PCZT_NOTE_VALUE) + val ufvk = deriveTestUfvk() + db.initPcztRoundWithBundles(notesJson) + + val pcztJson = + JSONObject(db.buildTestGovernancePcztJson(ufvk, notesJson)) + val pcztBytes = pcztJson.getString("pczt_bytes").hexToByteArray() + val sighash = pcztJson.getString("pczt_sighash").hexToByteArray() + val extractedSighash = backend.extractPcztSighash(pcztBytes) + + assertTrue(pcztBytes.isNotEmpty()) + assertEquals(FIELD_BYTES, pcztJson.getString("rk").hexToByteArray().size) + assertEquals(FIELD_BYTES, sighash.size) + assertTrue(pcztJson.getInt("action_index") >= 0) + assertContentEquals(sighash, extractedSighash) + assertEquals( + JniRoundPhase.DELEGATION_CONSTRUCTED, + assertNotNull(db.getRoundState(PCZT_ROUND_ID)).roundPhase + ) + assertFailsWith { + db.generateHotkey(PCZT_ROUND_ID, HOTKEY_SEED) + } + assertFailsWith { + backend.extractSpendAuthSig(pcztBytes, pcztJson.getInt("action_index")) + } + } finally { + db.close() + } + } + + @Test + fun build_governance_pczt_accepts_mainnet_network_id() = + runTest { + val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID) + try { + val notesJson = notesJson(noteCount = 6, value = PCZT_NOTE_VALUE) + val ufvk = deriveTestUfvk(networkId = MAINNET_NETWORK_ID) + db.initPcztRoundWithBundles(notesJson) + + val pcztJson = + JSONObject( + db.buildTestGovernancePcztJson( + ufvk = ufvk, + notesJson = notesJson, + networkId = MAINNET_NETWORK_ID + ) + ) + + assertTrue(pcztJson.getString("pczt_bytes").hexToByteArray().isNotEmpty()) + } finally { + db.close() + } + } + private fun newDbPath() = createTempDirectory("voting-db-").resolve("voting.db").toFile().absolutePath + + private suspend fun deriveTestUfvk( + seed: ByteArray = HOTKEY_SEED, + networkId: Int = TESTNET_NETWORK_ID + ): String = + RustDerivationTool + .new() + .deriveUnifiedFullViewingKeys(seed, networkId, 1) + .first() + + private suspend fun VotingRustBackend.VotingDb.initPcztRoundWithBundles( + notesJson: String, + roundId: String = PCZT_ROUND_ID + ) { + initRound( + roundId = roundId, + snapshotHeight = SNAPSHOT_HEIGHT, + eaPK = EA_PK, + ncRoot = NC_ROOT, + nullifierIMTRoot = NULLIFIER_IMT_ROOT, + sessionJson = null + ) + setupBundles(roundId, notesJson) + generateHotkey(roundId, HOTKEY_SEED) + } + + private suspend fun VotingRustBackend.VotingDb.buildTestGovernancePcztJson( + ufvk: String, + notesJson: String, + walletSeed: ByteArray = HOTKEY_SEED, + networkId: Int = TESTNET_NETWORK_ID, + roundId: String = PCZT_ROUND_ID + ) = buildGovernancePcztJson( + roundId = roundId, + bundleIndex = 1, + ufvk = ufvk, + networkId = networkId, + accountIndex = ACCOUNT_INDEX, + notesJson = notesJson, + walletSeed = walletSeed, + seedFingerprint = SEED_FINGERPRINT, + roundName = ROUND_NAME, + addressIndex = ADDRESS_INDEX + ) + + private fun notesJson( + noteCount: Int, + value: Long = NOTE_VALUE, + positionOffset: Long = 0, + ufvkString: String = "" + ): String = + JSONArray() + .apply { + repeat(noteCount) { index -> + put( + noteJson( + value = value, + position = positionOffset + index.toLong(), + byteValue = index + 1, + ufvkString = ufvkString + ) + ) + } + }.toString() + + private fun noteJson( + value: Long, + position: Long, + byteValue: Int, + scope: Int = 0, + ufvkString: String = "" + ) = JSONObject() + .put("commitment", repeatedHex(byteValue)) + .put("nullifier", repeatedHex(byteValue + 1)) + .put("value", value) + .put("position", position) + .put("diversifier", repeatedHex(0, DIVERSIFIER_BYTES)) + .put("rho", repeatedHex(0)) + .put("rseed", repeatedHex(0)) + .put("scope", scope) + .put("ufvk_str", ufvkString) + + private fun repeatedHex( + byteValue: Int, + size: Int = FIELD_BYTES + ) = ByteArray(size) { byteValue.toByte() }.toHexString() } diff --git a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/JniConstants.kt b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/JniConstants.kt index 1158c5827..113532c21 100644 --- a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/JniConstants.kt +++ b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/JniConstants.kt @@ -22,3 +22,18 @@ const val JNI_METADATA_KEY_SK_SIZE = 32 * The number of bytes in a chain code. It's used e.g. in [JniMetadataKey.chainCode] */ const val JNI_METADATA_KEY_CHAIN_CODE_SIZE = 32 + +/** + * The number of bytes in a voting hotkey public key. It's used e.g. in [HotkeyPublicKey.value] + */ +const val JNI_HOTKEY_PUBLIC_KEY_BYTES_SIZE = 32 + +/** + * Voting JNI network id for testnet. Matches [cash.z.ecc.android.sdk.model.ZcashNetwork.ID_TESTNET]. + */ +const val JNI_VOTING_NETWORK_ID_TESTNET = 0 + +/** + * Voting JNI network id for mainnet. Matches [cash.z.ecc.android.sdk.model.ZcashNetwork.ID_MAINNET]. + */ +const val JNI_VOTING_NETWORK_ID_MAINNET = 1 diff --git a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt index 95c1fc521..1f1c527d9 100644 --- a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt +++ b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/jni/VotingRustBackend.kt @@ -1,9 +1,12 @@ package cash.z.ecc.android.sdk.internal.jni import androidx.annotation.Keep +import cash.z.ecc.android.sdk.internal.SdkDispatchers +import cash.z.ecc.android.sdk.internal.model.voting.JniBundleSetupResult import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord +import cash.z.ecc.android.sdk.internal.model.voting.JniVotingHotkey import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -13,14 +16,47 @@ import kotlinx.coroutines.withContext @Suppress("TooManyFunctions", "LongParameterList") class VotingRustBackend private constructor() { @Throws(RuntimeException::class) - fun computeShareNullifier( + suspend fun computeShareNullifier( voteCommitment: ByteArray, shareIndex: Int, blind: ByteArray - ): ByteArray = computeShareNullifierNative(voteCommitment, shareIndex, blind) + ): ByteArray = + withContext(Dispatchers.IO) { + computeShareNullifierNative(voteCommitment, shareIndex, blind) + } - suspend fun openVotingDb(dbPath: String, walletId: String): VotingDb = + @Throws(RuntimeException::class) + suspend fun computeBundleSetup(notesJson: String): JniBundleSetupResult = + withContext(Dispatchers.IO) { + computeBundleSetupNative(notesJson) + ?: error("computeBundleSetup returned null") + } + + @Throws(RuntimeException::class) + suspend fun warmProvingCaches() = + withContext(Dispatchers.IO) { + warmProvingCachesNative() + } + + @Throws(RuntimeException::class) + suspend fun extractPcztSighash(pcztBytes: ByteArray): ByteArray = + withContext(Dispatchers.IO) { + extractPcztSighashNative(pcztBytes) + ?: error("extractPcztSighash returned null") + } + + @Throws(RuntimeException::class) + suspend fun extractSpendAuthSig( + signedPcztBytes: ByteArray, + actionIndex: Int + ): ByteArray = withContext(Dispatchers.IO) { + extractSpendAuthSigNative(signedPcztBytes, actionIndex) + ?: error("extractSpendAuthSig returned null") + } + + suspend fun openVotingDb(dbPath: String, walletId: String): VotingDb = + withContext(SdkDispatchers.DATABASE_IO) { openVotingDbNative(dbPath, walletId).let { dbHandle -> check(dbHandle != 0L) { "openVotingDb failed for dbPath=$dbPath" @@ -38,7 +74,7 @@ class VotingRustBackend private constructor() { suspend fun close() { accessMutex.withLock { dbHandle?.let { handle -> - withContext(Dispatchers.IO) { + withContext(SdkDispatchers.DATABASE_IO) { closeVotingDbNative(handle) } dbHandle = null @@ -74,6 +110,10 @@ class VotingRustBackend private constructor() { suspend fun listRounds(): Array = withHandle { handle -> listRoundsNative(handle) } + @Throws(RuntimeException::class) + suspend fun getBundleCount(roundId: String): Int = + withHandle { handle -> getBundleCountNative(handle, roundId) } + @Throws(RuntimeException::class) suspend fun getVotes(roundId: String): Array = withHandle { handle -> getVotesNative(handle, roundId) } @@ -89,13 +129,62 @@ class VotingRustBackend private constructor() { ): Long = withHandle { handle -> deleteSkippedBundlesNative(handle, roundId, keepCount) } + @Throws(RuntimeException::class) + suspend fun setupBundles( + roundId: String, + notesJson: String + ): JniBundleSetupResult = + withHandle { handle -> + setupBundlesNative(handle, roundId, notesJson) + ?: error("setupBundles returned null for roundId=$roundId") + } + + @Throws(RuntimeException::class) + suspend fun generateHotkey( + roundId: String, + seed: ByteArray + ): JniVotingHotkey = + withHandle { handle -> + generateHotkeyNative(handle, roundId, seed) + ?: error("generateHotkey returned null for roundId=$roundId") + } + + @Throws(RuntimeException::class) + suspend fun buildGovernancePcztJson( + roundId: String, + bundleIndex: Int, + ufvk: String, + networkId: Int, + accountIndex: Int, + notesJson: String, + walletSeed: ByteArray, + seedFingerprint: ByteArray, + roundName: String, + addressIndex: Int + ): String = + withHandle { handle -> + buildGovernancePcztJsonNative( + handle, + roundId, + bundleIndex, + ufvk, + networkId, + accountIndex, + notesJson, + walletSeed, + seedFingerprint, + roundName, + addressIndex + ) ?: error("buildGovernancePczt returned null") + } + private suspend fun withHandle(block: (Long) -> T): T = accessMutex.withLock { val handle = checkNotNull(dbHandle) { "Voting DB handle is closed" } - withContext(Dispatchers.IO) { + withContext(SdkDispatchers.DATABASE_IO) { block(handle) } } @@ -116,6 +205,10 @@ class VotingRustBackend private constructor() { blind: ByteArray ): ByteArray + @JvmStatic + @Throws(RuntimeException::class) + private external fun warmProvingCachesNative() + @JvmStatic @Throws(RuntimeException::class) private external fun openVotingDbNative(dbPath: String, walletId: String): Long @@ -144,6 +237,10 @@ class VotingRustBackend private constructor() { @Throws(RuntimeException::class) private external fun listRoundsNative(dbHandle: Long): Array + @JvmStatic + @Throws(RuntimeException::class) + private external fun getBundleCountNative(dbHandle: Long, roundId: String): Int + @JvmStatic @Throws(RuntimeException::class) private external fun getVotesNative(dbHandle: Long, roundId: String): Array @@ -159,5 +256,52 @@ class VotingRustBackend private constructor() { roundId: String, keepCount: Int ): Long + + @JvmStatic + @Throws(RuntimeException::class) + private external fun computeBundleSetupNative(notesJson: String): JniBundleSetupResult? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun setupBundlesNative( + dbHandle: Long, + roundId: String, + notesJson: String + ): JniBundleSetupResult? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun generateHotkeyNative( + dbHandle: Long, + roundId: String, + seed: ByteArray + ): JniVotingHotkey? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun buildGovernancePcztJsonNative( + dbHandle: Long, + roundId: String, + bundleIndex: Int, + ufvk: String, + networkId: Int, + accountIndex: Int, + notesJson: String, + walletSeed: ByteArray, + seedFingerprint: ByteArray, + roundName: String, + addressIndex: Int + ): String? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun extractPcztSighashNative(pcztBytes: ByteArray): ByteArray? + + @JvmStatic + @Throws(RuntimeException::class) + private external fun extractSpendAuthSigNative( + signedPcztBytes: ByteArray, + actionIndex: Int + ): ByteArray? } } diff --git a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt index cf017d986..ebbb3f14d 100644 --- a/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt +++ b/backend-lib/src/main/java/cash/z/ecc/android/sdk/internal/model/voting/JniVotingModels.kt @@ -1,6 +1,44 @@ package cash.z.ecc.android.sdk.internal.model.voting import androidx.annotation.Keep +import cash.z.ecc.android.sdk.internal.jni.JNI_HOTKEY_PUBLIC_KEY_BYTES_SIZE + +@ConsistentCopyVisibility +data class HotkeyPublicKey internal constructor( + val value: ByteArray +) { + init { + require(value.size == JNI_HOTKEY_PUBLIC_KEY_BYTES_SIZE) { + "HotkeyPublicKey must be $JNI_HOTKEY_PUBLIC_KEY_BYTES_SIZE bytes, got ${value.size}" + } + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is HotkeyPublicKey) return false + return value.contentEquals(other.value) + } + + override fun hashCode(): Int = value.contentHashCode() + + override fun toString(): String = "HotkeyPublicKey(${value.toHexString()})" + + companion object { + internal fun new(bytes: ByteArray) = HotkeyPublicKey(bytes) + } +} + +private fun ByteArray.toHexString() = joinToString("") { "%02x".format(it) } + +@Keep +@ConsistentCopyVisibility +data class JniVotingHotkey internal constructor( + val publicKey: HotkeyPublicKey, + val address: String +) { + internal constructor(pk: ByteArray, addr: String) : + this(HotkeyPublicKey.new(pk), addr) +} // Must match PHASE_* constants in backend-lib/src/main/rust/voting/helpers.rs. internal const val JNI_ROUND_PHASE_INITIALIZED = 0 @@ -9,6 +47,16 @@ internal const val JNI_ROUND_PHASE_DELEGATION_CONSTRUCTED = 2 internal const val JNI_ROUND_PHASE_DELEGATION_PROVED = 3 internal const val JNI_ROUND_PHASE_VOTE_READY = 4 +@Keep +data class JniBundleSetupResult( + val bundleCount: Int, + val eligibleWeight: Long, + val bundleWeights: List +) { + internal constructor(bundleCount: Int, eligibleWeight: Long, bundleWeights: LongArray) : + this(bundleCount, eligibleWeight, bundleWeights.toList()) +} + @Keep data class JniRoundState( val roundId: String, diff --git a/backend-lib/src/main/rust/voting.rs b/backend-lib/src/main/rust/voting.rs index bbea42485..fb01a920e 100644 --- a/backend-lib/src/main/rust/voting.rs +++ b/backend-lib/src/main/rust/voting.rs @@ -2,20 +2,25 @@ use anyhow::anyhow; use jni::{ - JNIEnv, objects::{JByteArray, JClass, JObject, JString, JValue}, - sys::{jboolean, jbyteArray, jint, jlong, jobject, jobjectArray}, + sys::{jboolean, jbyteArray, jint, jlong, jobject, jobjectArray, jstring}, + JNIEnv, }; +use orchard::keys::Scope; +use secrecy::{ExposeSecret, SecretVec}; use std::{ collections::HashMap, sync::{ - Arc, Mutex, OnceLock, atomic::{AtomicI64, Ordering}, + Arc, Mutex, OnceLock, }, }; +use zcash_client_backend::keys::{UnifiedFullViewingKey, UnifiedSpendingKey}; +use zcash_protocol::consensus::{BranchId, Network, NetworkConstants}; use zcash_voting as voting; use voting::storage::{RoundPhase, RoundState, RoundSummary, VoteRecord, VotingDb}; +use voting::types::{GovernancePczt, NoteInfo}; use crate::utils::{ catch_unwind, exception::unwrap_exc_or, java_nullable_string_to_rust, java_string_to_rust, @@ -23,6 +28,10 @@ use crate::utils::{ }; mod db; +mod delegation; mod helpers; +mod json; +mod notes; mod rounds; mod share_tracking; +mod util; diff --git a/backend-lib/src/main/rust/voting/db.rs b/backend-lib/src/main/rust/voting/db.rs index 670db4747..cfb7bd641 100644 --- a/backend-lib/src/main/rust/voting/db.rs +++ b/backend-lib/src/main/rust/voting/db.rs @@ -1,21 +1,65 @@ use super::*; +use std::{ + ops::Deref, + sync::{MutexGuard, Weak}, +}; static NEXT_DB_HANDLE: AtomicI64 = AtomicI64::new(1); -static DB_REGISTRY: OnceLock>>> = OnceLock::new(); +static DB_REGISTRY: OnceLock>>> = OnceLock::new(); +static DB_BY_KEY: OnceLock>>> = OnceLock::new(); -fn registry() -> &'static Mutex>> { +#[derive(Clone, Eq, Hash, PartialEq)] +struct DbKey { + path: String, + wallet_id: String, +} + +pub(super) struct VotingDbHandle { + db: VotingDb, + access_mutex: Mutex<()>, +} + +impl VotingDbHandle { + fn open(path: &str, wallet_id: &str) -> anyhow::Result { + let db = VotingDb::open(path).map_err(|e| anyhow!("VotingDb::open failed: {}", e))?; + db.set_wallet_id(wallet_id); + + Ok(Self { + db, + access_mutex: Mutex::new(()), + }) + } + + pub(super) fn access_lock(&self) -> anyhow::Result> { + self.access_mutex + .lock() + .map_err(|_| anyhow!("voting DB access mutex poisoned")) + } +} + +impl Deref for VotingDbHandle { + type Target = VotingDb; + + fn deref(&self) -> &Self::Target { + &self.db + } +} + +fn registry() -> &'static Mutex>> { DB_REGISTRY.get_or_init(|| Mutex::new(HashMap::new())) } +fn db_by_key() -> &'static Mutex>> { + DB_BY_KEY.get_or_init(|| Mutex::new(HashMap::new())) +} + fn next_handle() -> anyhow::Result { NEXT_DB_HANDLE - .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| { - id.checked_add(1).filter(|next| *next > 0) - }) + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |id| id.checked_add(1)) .map_err(|_| anyhow!("voting DB handle space exhausted")) } -pub(super) fn db_from_handle(handle: jlong) -> anyhow::Result> { +pub(super) fn db_from_handle(handle: jlong) -> anyhow::Result> { if handle <= 0 { return Err(anyhow!("Voting DB handle must be positive, got {handle}")); } @@ -28,6 +72,29 @@ pub(super) fn db_from_handle(handle: jlong) -> anyhow::Result> { .ok_or_else(|| anyhow!("Voting DB handle is closed or unknown: {handle}")) } +fn open_managed_db(path: &str, wallet_id: &str) -> anyhow::Result> { + if path == ":memory:" { + return Ok(Arc::new(VotingDbHandle::open(path, wallet_id)?)); + } + + let key = DbKey { + path: path.to_string(), + wallet_id: wallet_id.to_string(), + }; + let mut dbs = db_by_key() + .lock() + .map_err(|_| anyhow!("voting DB key registry mutex poisoned"))?; + dbs.retain(|_, db| db.strong_count() > 0); + + if let Some(db) = dbs.get(&key).and_then(Weak::upgrade) { + return Ok(db); + } + + let db = Arc::new(VotingDbHandle::open(path, wallet_id)?); + dbs.insert(key, Arc::downgrade(&db)); + Ok(db) +} + #[unsafe(no_mangle)] pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_openVotingDbNative< 'local, @@ -44,13 +111,12 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_ope return Err(anyhow!("walletId must not be empty")); } - let db = VotingDb::open(&path).map_err(|e| anyhow!("VotingDb::open failed: {}", e))?; - db.set_wallet_id(&wallet_id); + let db = open_managed_db(&path, &wallet_id)?; let handle = next_handle()?; registry() .lock() .map_err(|_| anyhow!("voting DB registry mutex poisoned"))? - .insert(handle, Arc::new(db)); + .insert(handle, db); Ok(handle) }); diff --git a/backend-lib/src/main/rust/voting/delegation.rs b/backend-lib/src/main/rust/voting/delegation.rs new file mode 100644 index 000000000..7cf2c8aa7 --- /dev/null +++ b/backend-lib/src/main/rust/voting/delegation.rs @@ -0,0 +1,241 @@ +use super::db::*; +use super::helpers::*; +use super::json::*; +use super::*; + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_buildGovernancePcztJsonNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, + bundle_index: jint, + ufvk: JString<'local>, + network_id: jint, + account_index: jint, + notes_json: JString<'local>, + wallet_seed: JByteArray<'local>, + seed_fingerprint: JByteArray<'local>, + round_name: JString<'local>, + address_index: jint, +) -> jstring { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; + let network = network_from_id(network_id)?; + let bundle_index = jint_to_u32(bundle_index, "bundle_index")?; + let account_index = jint_to_u32(account_index, "account_index")?; + let address_index = jint_to_u32(address_index, "address_index")?; + let ufvk_str = java_string_to_rust(env, &ufvk)?; + let fvk_bytes = orchard_fvk_bytes(&ufvk_str, network)?; + + let seed_bytes = + java_secret_bytes_at_least(env, &wallet_seed, "walletSeed", PROTOCOL_FIELD_BYTES)?; + let derived_fvk_bytes = + orchard_fvk_bytes_from_wallet_seed(seed_bytes.expose_secret(), network, account_index)?; + if derived_fvk_bytes != fvk_bytes { + return Err(anyhow!( + "ufvk does not match walletSeed for network_id={network_id} account_index={account_index}" + )); + } + let hotkey_raw_address = hotkey_orchard_raw_address_from_wallet_seed( + seed_bytes.expose_secret(), + network, + account_index, + address_index, + )?; + let seed_fingerprint = java_bytes32(env, &seed_fingerprint, "seedFingerprint")?; + + let json_notes: Vec = json_from_jstring(env, ¬es_json, "notesJson")?; + let notes: Vec = json_notes + .into_iter() + .map(NoteInfo::try_from) + .collect::>()?; + let bundle_notes = bundled_notes_for_index(¬es, bundle_index)?; + + let round_id = java_string_to_rust(env, &round_id)?; + require_round_phase_for_delegation_construction(&db, &round_id)?; + let round_name = java_string_to_rust(env, &round_name)?; + let pczt = db + .build_governance_pczt( + &round_id, + bundle_index, + &bundle_notes, + &fvk_bytes, + &hotkey_raw_address, + nu6_branch_id(), + network.coin_type(), + &seed_fingerprint, + account_index, + &round_name, + address_index, + ) + .map_err(|e| anyhow!("build_governance_pczt: {}", e))?; + update_round_phase_forward(&db, &round_id, RoundPhase::DelegationConstructed)?; + + json_to_jstring(env, &JsonGovernancePczt::try_from(pczt)?) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_extractPcztSighashNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + pczt_bytes: JByteArray<'local>, +) -> jbyteArray { + let res = catch_unwind(&mut env, |env| { + let bytes = java_bytes(env, &pczt_bytes, "pcztBytes")?; + let sighash = voting::action::extract_pczt_sighash(&bytes) + .map_err(|e| anyhow!("extract_pczt_sighash: {}", e))?; + Ok(env.byte_array_from_slice(&sighash)?.into_raw()) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_extractSpendAuthSigNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + signed_pczt_bytes: JByteArray<'local>, + action_index: jint, +) -> jbyteArray { + let res = catch_unwind(&mut env, |env| { + let bytes = java_bytes(env, &signed_pczt_bytes, "signedPcztBytes")?; + let action_index = jint_to_usize(action_index, "action_index")?; + let sig = extract_indexed_spend_auth_sig(&bytes, action_index)?; + Ok(env.byte_array_from_slice(&sig)?.into_raw()) + }); + unwrap_exc_or(&mut env, res, std::ptr::null_mut()) +} + +fn extract_indexed_spend_auth_sig( + signed_pczt_bytes: &[u8], + action_index: usize, +) -> anyhow::Result<[u8; 64]> { + let pczt = pczt::Pczt::parse(signed_pczt_bytes).map_err(|e| { + anyhow!( + "extract_spend_auth_sig: failed to parse signed PCZT: {:?}", + e + ) + })?; + let actions = pczt.orchard().actions(); + if action_index < actions.len() { + if let Some(sig) = actions[action_index].spend().spend_auth_sig() { + return Ok(*sig); + } + + return Err(anyhow!( + "extract_spend_auth_sig: action {action_index} has no spend_auth_sig" + )); + } + Err(anyhow!( + "extract_spend_auth_sig: action_index {action_index} out of bounds for {} actions", + actions.len() + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use orchard::keys::{FullViewingKey, Scope, SpendAuthorizingKey, SpendingKey}; + use voting::types::VotingRoundParams; + + #[test] + fn extract_spend_auth_sig_accepts_signed_governance_pczt() { + let spending_key = SpendingKey::from_bytes([0x42; 32]).expect("valid spending key"); + let fvk = FullViewingKey::from(&spending_key); + let hotkey_spending_key = SpendingKey::from_bytes([0x43; 32]).expect("valid hotkey"); + let hotkey_fvk = FullViewingKey::from(&hotkey_spending_key); + let hotkey_address = hotkey_fvk + .address_at(0u32, Scope::External) + .to_raw_address_bytes() + .to_vec(); + let result = voting::action::build_governance_pczt( + &[note_info()], + &round_params(), + &fvk.to_bytes().to_vec(), + &hotkey_address, + nu6_branch_id(), + Network::TestNetwork.coin_type(), + &[0xAA; 32], + 0, + "Test Round", + ) + .expect("governance PCZT"); + + let pczt = pczt::Pczt::parse(&result.pczt_bytes).expect("parse PCZT"); + let mut signer = pczt::roles::signer::Signer::new(pczt).expect("signer"); + let spend_authorizing_key = SpendAuthorizingKey::from(&spending_key); + signer + .sign_orchard(result.action_index, &spend_authorizing_key) + .expect("sign orchard action"); + let signed_pczt = signer.finish().serialize(); + let sig = extract_indexed_spend_auth_sig(&signed_pczt, result.action_index).unwrap(); + + assert_ne!(sig, [0u8; 64]); + } + + #[test] + fn extract_spend_auth_sig_rejects_unsigned_governance_pczt() { + let result = test_governance_pczt(); + let err = + extract_indexed_spend_auth_sig(&result.pczt_bytes, result.action_index).unwrap_err(); + + assert!(err.to_string().contains("has no spend_auth_sig")); + } + + fn test_governance_pczt() -> GovernancePczt { + let spending_key = SpendingKey::from_bytes([0x42; 32]).expect("valid spending key"); + let fvk = FullViewingKey::from(&spending_key); + let hotkey_spending_key = SpendingKey::from_bytes([0x43; 32]).expect("valid hotkey"); + let hotkey_fvk = FullViewingKey::from(&hotkey_spending_key); + let hotkey_address = hotkey_fvk + .address_at(0u32, Scope::External) + .to_raw_address_bytes() + .to_vec(); + voting::action::build_governance_pczt( + &[note_info()], + &round_params(), + &fvk.to_bytes().to_vec(), + &hotkey_address, + nu6_branch_id(), + Network::TestNetwork.coin_type(), + &[0xAA; 32], + 0, + "Test Round", + ) + .expect("governance PCZT") + } + + fn note_info() -> NoteInfo { + NoteInfo { + commitment: vec![1; PROTOCOL_FIELD_BYTES], + nullifier: vec![2; PROTOCOL_FIELD_BYTES], + value: 15_000_000, + position: 0, + diversifier: vec![0; 11], + rho: vec![0; PROTOCOL_FIELD_BYTES], + rseed: vec![0; PROTOCOL_FIELD_BYTES], + scope: 0, + ufvk_str: String::new(), + } + } + + fn round_params() -> VotingRoundParams { + VotingRoundParams { + vote_round_id: "0101010101010101010101010101010101010101010101010101010101010101" + .to_string(), + snapshot_height: 100_000, + ea_pk: vec![0xEA; PROTOCOL_FIELD_BYTES], + nc_root: vec![0x01; PROTOCOL_FIELD_BYTES], + nullifier_imt_root: vec![0x02; PROTOCOL_FIELD_BYTES], + } + } +} diff --git a/backend-lib/src/main/rust/voting/helpers.rs b/backend-lib/src/main/rust/voting/helpers.rs index 7960c98a9..41752df33 100644 --- a/backend-lib/src/main/rust/voting/helpers.rs +++ b/backend-lib/src/main/rust/voting/helpers.rs @@ -9,11 +9,24 @@ const PHASE_VOTE_READY: u32 = 4; const JNI_ROUND_SUMMARY: &str = "cash/z/ecc/android/sdk/internal/model/voting/JniRoundSummary"; const JNI_VOTE_RECORD: &str = "cash/z/ecc/android/sdk/internal/model/voting/JniVoteRecord"; +const JNI_VOTING_HOTKEY: &str = "cash/z/ecc/android/sdk/internal/model/voting/JniVotingHotkey"; +const JNI_BUNDLE_SETUP_RESULT: &str = + "cash/z/ecc/android/sdk/internal/model/voting/JniBundleSetupResult"; +// Must match JniVotingHotkey(ByteArray, String) in JniVotingModels.kt. +const JNI_VOTING_HOTKEY_CTOR_SIG: &str = "([BLjava/lang/String;)V"; +// Must match JniBundleSetupResult(Int, Long, LongArray) in JniVotingModels.kt. +const JNI_BUNDLE_SETUP_RESULT_CTOR_SIG: &str = "(IJ[J)V"; + +pub(super) const ORCHARD_RAW_ADDRESS_BYTES: usize = 43; +pub(super) const ORCHARD_FVK_BYTES: usize = 96; pub(super) const PROTOCOL_FIELD_BYTES: usize = 32; pub(super) const VOTE_COMMITMENT_BYTES: usize = PROTOCOL_FIELD_BYTES; pub(super) const BLIND_BYTES: usize = PROTOCOL_FIELD_BYTES; -pub(super) const SHARE_NULLIFIER_BYTES: usize = 32; +pub(super) const SHARE_NULLIFIER_BYTES: usize = PROTOCOL_FIELD_BYTES; +pub(super) const HOTKEY_PUBLIC_KEY_BYTES: usize = PROTOCOL_FIELD_BYTES; +pub(super) const NETWORK_ID_TESTNET: jint = 0; +pub(super) const NETWORK_ID_MAINNET: jint = 1; struct JniRoundSummaryPayload { round_id: String, @@ -37,11 +50,19 @@ pub(super) fn jlong_to_u64(value: jlong, field: &str) -> anyhow::Result { u64::try_from(value).map_err(|_| anyhow!("{field} must be non-negative, got {value}")) } -fn u32_to_jint(value: u32, field: &str) -> anyhow::Result { +pub(super) fn jint_to_usize(value: jint, field: &str) -> anyhow::Result { + usize::try_from(value).map_err(|_| anyhow!("{field} must be non-negative, got {value}")) +} + +pub(super) fn u32_to_jint(value: u32, field: &str) -> anyhow::Result { jint::try_from(value).map_err(|_| anyhow!("{field} exceeds signed Int range: {value}")) } -fn u64_to_jlong(value: u64, field: &str) -> anyhow::Result { +pub(super) fn usize_to_jint(value: usize, field: &str) -> anyhow::Result { + jint::try_from(value).map_err(|_| anyhow!("{field} exceeds signed Int range: {value}")) +} + +pub(super) fn u64_to_jlong(value: u64, field: &str) -> anyhow::Result { jlong::try_from(value).map_err(|_| anyhow!("{field} exceeds signed Long range: {value}")) } @@ -56,6 +77,31 @@ pub(super) fn require_len(bytes: Vec, field: &str, expected: usize) -> anyho } } +pub(super) fn require_min_len( + bytes: Vec, + field: &str, + minimum: usize, +) -> anyhow::Result> { + if bytes.len() >= minimum { + Ok(bytes) + } else { + Err(anyhow!( + "{field} must be at least {minimum} bytes, got {}", + bytes.len() + )) + } +} + +pub(super) fn require_32( + bytes: Vec, + field: &str, +) -> anyhow::Result<[u8; PROTOCOL_FIELD_BYTES]> { + let bytes = require_len(bytes, field, PROTOCOL_FIELD_BYTES)?; + bytes + .try_into() + .map_err(|_| anyhow!("{field} must be exactly {PROTOCOL_FIELD_BYTES} bytes")) +} + pub(super) fn java_bytes( env: &mut JNIEnv<'_>, array: &JByteArray<'_>, @@ -100,11 +146,95 @@ pub(super) fn round_phase_to_u32(phase: RoundPhase) -> u32 { } } +pub(super) fn java_secret_bytes_at_least( + env: &mut JNIEnv<'_>, + array: &JByteArray<'_>, + field: &str, + minimum: usize, +) -> anyhow::Result> { + require_min_len(java_bytes(env, array, field)?, field, minimum).map(SecretVec::new) +} + +pub(super) fn java_bytes32( + env: &mut JNIEnv<'_>, + array: &JByteArray<'_>, + field: &str, +) -> anyhow::Result<[u8; PROTOCOL_FIELD_BYTES]> { + require_32(java_bytes(env, array, field)?, field) +} + +pub(super) fn network_from_id(id: jint) -> anyhow::Result { + match id { + NETWORK_ID_TESTNET => Ok(Network::TestNetwork), + NETWORK_ID_MAINNET => Ok(Network::MainNetwork), + _ => Err(anyhow!("invalid network_id {}", id)), + } +} + +pub(super) fn hotkey_orchard_raw_address_from_wallet_seed( + wallet_seed: &[u8], + network: Network, + account_index: u32, + address_index: u32, +) -> anyhow::Result> { + let account_id = zip32::AccountId::try_from(account_index) + .map_err(|_| anyhow!("invalid account_index {}", account_index))?; + let usk = UnifiedSpendingKey::from_seed(&network, wallet_seed, account_id) + .map_err(|e| anyhow!("failed to derive hotkey USK from wallet seed: {}", e))?; + let fvk = usk.to_unified_full_viewing_key(); + let orchard_fvk = fvk + .orchard() + .ok_or_else(|| anyhow!("hotkey UFVK has no Orchard component"))?; + // voting-circuits treats address_index as the diversifier index for the + // external Orchard scope when reconstructing the hotkey address for ZKP #2. + let addr = orchard_fvk.address_at(address_index, Scope::External); + require_len( + addr.to_raw_address_bytes().to_vec(), + "hotkey_raw_address", + ORCHARD_RAW_ADDRESS_BYTES, + ) +} + +pub(super) fn orchard_fvk_bytes_from_wallet_seed( + wallet_seed: &[u8], + network: Network, + account_index: u32, +) -> anyhow::Result> { + let account_id = zip32::AccountId::try_from(account_index) + .map_err(|_| anyhow!("invalid account_index {}", account_index))?; + let usk = UnifiedSpendingKey::from_seed(&network, wallet_seed, account_id) + .map_err(|e| anyhow!("failed to derive USK from wallet seed: {}", e))?; + let ufvk = usk.to_unified_full_viewing_key(); + let orchard_fvk = ufvk + .orchard() + .ok_or_else(|| anyhow!("derived UFVK has no Orchard component"))?; + require_len( + orchard_fvk.to_bytes().to_vec(), + "derived_orchard_fvk", + ORCHARD_FVK_BYTES, + ) +} + +pub(super) fn orchard_fvk_bytes(ufvk_str: &str, network: Network) -> anyhow::Result> { + let ufvk = UnifiedFullViewingKey::decode(&network, ufvk_str) + .map_err(|e| anyhow!("failed to decode UFVK: {}", e))?; + let fvk = ufvk + .orchard() + .ok_or_else(|| anyhow!("UFVK has no Orchard component"))?; + require_len(fvk.to_bytes().to_vec(), "orchard_fvk", ORCHARD_FVK_BYTES) +} + +// NU6 branch ID used by the governance PCZT signer path. Revisit this when +// the voting transaction format moves to a later consensus branch. +pub(super) fn nu6_branch_id() -> u32 { + BranchId::Nu6.into() +} + pub(super) fn make_jni_round_state<'local>( env: &mut JNIEnv<'local>, state: RoundState, ) -> anyhow::Result { - let phase = round_phase_to_u32(state.phase) as i32; + let phase = round_phase_to_u32(state.phase); let class = env.find_class("cash/z/ecc/android/sdk/internal/model/voting/JniRoundState")?; let round_id_obj: JObject<'local> = env.new_string(&state.round_id)?.into(); let hotkey_obj: JObject<'local> = match &state.hotkey_address { @@ -113,7 +243,11 @@ pub(super) fn make_jni_round_state<'local>( }; let long_class = env.find_class("java/lang/Long")?; let weight_obj: JObject<'local> = match state.delegated_weight { - Some(w) => env.new_object(&long_class, "(J)V", &[JValue::Long(w as i64)])?, + Some(w) => env.new_object( + &long_class, + "(J)V", + &[JValue::Long(u64_to_jlong(w, "delegated_weight")?)], + )?, None => JObject::null(), }; let obj = env.new_object( @@ -123,8 +257,8 @@ pub(super) fn make_jni_round_state<'local>( "(Ljava/lang/String;IJLjava/lang/String;Ljava/lang/Long;Z)V", &[ JValue::Object(&round_id_obj), - JValue::Int(phase), - JValue::Long(state.snapshot_height as i64), + JValue::Int(u32_to_jint(phase, "round_phase")?), + JValue::Long(u64_to_jlong(state.snapshot_height, "snapshot_height")?), JValue::Object(&hotkey_obj), JValue::Object(&weight_obj), JValue::Bool(state.proof_generated as jboolean), @@ -213,3 +347,242 @@ impl TryFrom for JniVoteRecordPayload { }) } } + +/// Builds the Kotlin hotkey JNI model after enforcing the expected key widths. +/// The secret key is intentionally not surfaced across JNI. +pub(super) fn make_jni_voting_hotkey<'local>( + env: &mut JNIEnv<'local>, + hotkey: voting::types::VotingHotkey, +) -> anyhow::Result { + let class = env.find_class(JNI_VOTING_HOTKEY)?; + let secret_key = SecretVec::new(hotkey.secret_key); + let secret_key_len = secret_key.expose_secret().len(); + if secret_key_len != PROTOCOL_FIELD_BYTES { + return Err(anyhow!( + "hotkey_secret_key must be exactly {PROTOCOL_FIELD_BYTES} bytes, got {secret_key_len}" + )); + } + let public_key = require_len( + hotkey.public_key, + "hotkey_public_key", + HOTKEY_PUBLIC_KEY_BYTES, + )?; + let pk_obj: JObject<'local> = env.byte_array_from_slice(&public_key)?.into(); + let addr_obj: JObject<'local> = env.new_string(&hotkey.address)?.into(); + let obj = env.new_object( + &class, + JNI_VOTING_HOTKEY_CTOR_SIG, + &[JValue::Object(&pk_obj), JValue::Object(&addr_obj)], + )?; + Ok(obj.into_raw()) +} + +/// Builds the Kotlin bundle setup JNI model with width-checked Java primitives. +pub(super) fn make_jni_bundle_setup_result<'local>( + env: &mut JNIEnv<'local>, + count: u32, + weight: u64, + bundle_weights: &[u64], +) -> anyhow::Result { + let class = env.find_class(JNI_BUNDLE_SETUP_RESULT)?; + let weights = bundle_weights + .iter() + .enumerate() + .map(|(index, weight)| u64_to_jlong(*weight, &format!("bundle_weights[{index}]"))) + .collect::>>()?; + let weights_array = + env.new_long_array(usize_to_jint(weights.len(), "bundle_weights length")?)?; + env.set_long_array_region(&weights_array, 0, &weights)?; + let weights_array_obj = JObject::from(weights_array); + let obj = env.new_object( + &class, + JNI_BUNDLE_SETUP_RESULT_CTOR_SIG, + &[ + JValue::Int(u32_to_jint(count, "bundle_count")?), + JValue::Long(u64_to_jlong(weight, "eligible_weight")?), + JValue::Object(&weights_array_obj), + ], + )?; + Ok(obj.into_raw()) +} + +/// Runs the voting note chunker and returns total count, total eligible weight, +/// and each bundle's quantized voting weight. +pub(super) fn bundle_setup_from_notes(notes: &[NoteInfo]) -> anyhow::Result<(u32, u64, Vec)> { + let chunk_result = voting::types::chunk_notes(notes); + let bundle_weights = chunk_result + .bundles + .iter() + .map(|bundle| { + let total = bundle.iter().try_fold(0u64, |acc, note| { + acc.checked_add(note.value) + .ok_or_else(|| anyhow!("bundle note value overflows u64")) + })?; + Ok((total / voting::BALLOT_DIVISOR) * voting::BALLOT_DIVISOR) + }) + .collect::>>()?; + Ok(( + u32::try_from(chunk_result.bundles.len()) + .map_err(|_| anyhow!("bundle count is too large for u32"))?, + chunk_result.eligible_weight, + bundle_weights, + )) +} + +/// Recomputes deterministic note chunking and returns the requested bundle. +pub(super) fn bundled_notes_for_index( + notes: &[NoteInfo], + bundle_index: u32, +) -> anyhow::Result> { + let chunk_result = voting::types::chunk_notes(notes); + let bundle_index = usize::try_from(bundle_index) + .map_err(|_| anyhow!("bundle_index is too large for this platform: {bundle_index}"))?; + + chunk_result + .bundles + .get(bundle_index) + .cloned() + .ok_or_else(|| anyhow!("bundle_index {bundle_index} is not present in note bundle set")) +} + +/// Advances a round phase without allowing regressions; equal phases are +/// treated as idempotent. +pub(super) fn update_round_phase_forward( + db: &VotingDb, + round_id: &str, + phase: RoundPhase, +) -> anyhow::Result<()> { + let conn = db.conn(); + let wallet_id = db.wallet_id(); + let requested_rank = round_phase_to_u32(phase); + + let rows = conn + .execute( + "UPDATE rounds + SET phase = ?1 + WHERE round_id = ?2 + AND wallet_id = ?3 + AND phase < ?1", + rusqlite::params![phase as i32, round_id, wallet_id], + ) + .map_err(|e| anyhow!("update_round_phase: {}", e))?; + if rows > 0 { + return Ok(()); + } + + let current = voting::storage::queries::get_round_state(&conn, round_id, &wallet_id) + .map_err(|e| anyhow!("get_round_state after phase update: {}", e))? + .phase; + let current_rank = round_phase_to_u32(current); + if current_rank < requested_rank { + return Err(anyhow!( + "failed to advance round phase for {round_id}: current={current_rank}, requested={requested_rank}" + )); + } else if current_rank > requested_rank { + return Err(anyhow!( + "refusing to regress round phase for {round_id}: current={current_rank}, requested={requested_rank}" + )); + } + + Ok(()) +} + +/// Requires the hotkey step before PCZT construction and rejects calls after +/// later workflow phases have already begun. +pub(super) fn require_round_phase_for_delegation_construction( + db: &VotingDb, + round_id: &str, +) -> anyhow::Result<()> { + let conn = db.conn(); + let wallet_id = db.wallet_id(); + let current = voting::storage::queries::get_round_state(&conn, round_id, &wallet_id) + .map_err(|e| anyhow!("get_round_state before delegation construction: {}", e))? + .phase; + let current_rank = round_phase_to_u32(current); + let hotkey_rank = round_phase_to_u32(RoundPhase::HotkeyGenerated); + let constructed_rank = round_phase_to_u32(RoundPhase::DelegationConstructed); + + if current_rank < hotkey_rank { + return Err(anyhow!( + "round {round_id} must be HotkeyGenerated before building governance PCZT: current={current_rank}" + )); + } + + if current_rank > constructed_rank { + return Err(anyhow!( + "round {round_id} has already advanced beyond DelegationConstructed: current={current_rank}" + )); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_ROUND_ID: &str = "round-id"; + const TEST_WALLET_ID: &str = "wallet-id"; + + #[test] + fn hotkey_orchard_raw_address_uses_address_index() { + let seed = [0x42_u8; 64]; + + let index_zero = + hotkey_orchard_raw_address_from_wallet_seed(&seed, Network::TestNetwork, 0, 0).unwrap(); + let index_one = + hotkey_orchard_raw_address_from_wallet_seed(&seed, Network::TestNetwork, 0, 1).unwrap(); + + assert_eq!(ORCHARD_RAW_ADDRESS_BYTES, index_zero.len()); + assert_eq!(ORCHARD_RAW_ADDRESS_BYTES, index_one.len()); + assert_ne!(index_zero, index_one); + } + + #[test] + fn nu6_branch_id_comes_from_protocol_crate() { + assert_eq!(nu6_branch_id(), u32::from(BranchId::Nu6)); + } + + #[test] + fn update_round_phase_forward_is_idempotent() { + let db = test_db(); + + update_round_phase_forward(&db, TEST_ROUND_ID, RoundPhase::HotkeyGenerated) + .expect("first phase update"); + update_round_phase_forward(&db, TEST_ROUND_ID, RoundPhase::HotkeyGenerated) + .expect("idempotent phase update"); + + let state = db.get_round_state(TEST_ROUND_ID).expect("round state"); + assert_eq!(RoundPhase::HotkeyGenerated, state.phase); + } + + #[test] + fn update_round_phase_forward_rejects_regression() { + let db = test_db(); + + update_round_phase_forward(&db, TEST_ROUND_ID, RoundPhase::DelegationConstructed) + .expect("advance phase"); + let err = update_round_phase_forward(&db, TEST_ROUND_ID, RoundPhase::HotkeyGenerated) + .expect_err("regression rejected"); + + assert!(err.to_string().contains("refusing to regress round phase")); + } + + fn test_db() -> VotingDb { + let db = VotingDb::open(":memory:").expect("test DB"); + db.set_wallet_id(TEST_WALLET_ID); + db.init_round(&test_round_params(), None) + .expect("round initialized"); + db + } + + fn test_round_params() -> voting::types::VotingRoundParams { + voting::types::VotingRoundParams { + vote_round_id: TEST_ROUND_ID.to_string(), + snapshot_height: 100_000, + ea_pk: vec![0xEA; PROTOCOL_FIELD_BYTES], + nc_root: vec![0x01; PROTOCOL_FIELD_BYTES], + nullifier_imt_root: vec![0x02; PROTOCOL_FIELD_BYTES], + } + } +} diff --git a/backend-lib/src/main/rust/voting/json.rs b/backend-lib/src/main/rust/voting/json.rs new file mode 100644 index 000000000..ffda84cab --- /dev/null +++ b/backend-lib/src/main/rust/voting/json.rs @@ -0,0 +1,140 @@ +use super::helpers::*; +use super::*; +use serde::{Deserialize, Serialize}; + +const NOTE_SCOPE_EXTERNAL: u32 = 0; +const NOTE_SCOPE_INTERNAL: u32 = 1; +const ORCHARD_DIVERSIFIER_BYTES: usize = 11; + +pub(super) fn hex_enc(bytes: &[u8]) -> String { + hex::encode(bytes) +} + +pub(super) fn hex_dec(value: &str, field: &str) -> anyhow::Result> { + hex::decode(value).map_err(|e| anyhow!("field '{field}': invalid hex: {e}")) +} + +#[derive(Deserialize)] +pub(super) struct JsonNoteInfo { + pub(super) commitment: String, + pub(super) nullifier: String, + pub(super) value: u64, + pub(super) position: u64, + pub(super) diversifier: String, + pub(super) rho: String, + pub(super) rseed: String, + pub(super) scope: u32, + pub(super) ufvk_str: String, +} + +impl TryFrom for NoteInfo { + type Error = anyhow::Error; + + fn try_from(note: JsonNoteInfo) -> anyhow::Result { + let scope = require_note_scope(note.scope)?; + + Ok(NoteInfo { + commitment: require_len( + hex_dec(¬e.commitment, "commitment")?, + "commitment", + PROTOCOL_FIELD_BYTES, + )?, + nullifier: require_len( + hex_dec(¬e.nullifier, "nullifier")?, + "nullifier", + PROTOCOL_FIELD_BYTES, + )?, + value: note.value, + position: note.position, + diversifier: require_len( + hex_dec(¬e.diversifier, "diversifier")?, + "diversifier", + ORCHARD_DIVERSIFIER_BYTES, + )?, + rho: require_len(hex_dec(¬e.rho, "rho")?, "rho", PROTOCOL_FIELD_BYTES)?, + rseed: require_len( + hex_dec(¬e.rseed, "rseed")?, + "rseed", + PROTOCOL_FIELD_BYTES, + )?, + scope, + ufvk_str: note.ufvk_str, + }) + } +} + +fn require_note_scope(scope: u32) -> anyhow::Result { + match scope { + NOTE_SCOPE_EXTERNAL | NOTE_SCOPE_INTERNAL => Ok(scope), + _ => Err(anyhow!( + "scope must be {NOTE_SCOPE_EXTERNAL} (external) or {NOTE_SCOPE_INTERNAL} (internal), got {scope}" + )), + } +} + +#[derive(Serialize)] +pub(super) struct JsonGovernancePczt { + pub(super) pczt_bytes: String, + pub(super) rk: String, + pub(super) action_index: u32, + pub(super) pczt_sighash: String, +} + +impl TryFrom for JsonGovernancePczt { + type Error = anyhow::Error; + + fn try_from(pczt: GovernancePczt) -> anyhow::Result { + Ok(JsonGovernancePczt { + pczt_bytes: hex_enc(&pczt.pczt_bytes), + rk: hex_enc(&pczt.rk), + action_index: u32::try_from(pczt.action_index) + .map_err(|_| anyhow!("action_index is too large for u32: {}", pczt.action_index))?, + pczt_sighash: hex_enc(&pczt.pczt_sighash), + }) + } +} + +pub(super) fn json_to_jstring( + env: &mut JNIEnv<'_>, + value: &T, +) -> anyhow::Result { + let s = serde_json::to_string(value).map_err(|e| anyhow!("JSON serialization error: {}", e))?; + Ok(env.new_string(s)?.into_raw()) +} + +pub(super) fn json_from_jstring Deserialize<'de>>( + env: &mut JNIEnv<'_>, + value: &JString<'_>, + field: &str, +) -> anyhow::Result { + let s = java_string_to_rust(env, value)?; + serde_json::from_str(&s).map_err(|e| { + anyhow!( + "{field}: JSON parse error at line {}, column {}", + e.line(), + e.column() + ) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn json_note_info_rejects_unknown_scope() { + let note = JsonNoteInfo { + commitment: hex::encode([1u8; PROTOCOL_FIELD_BYTES]), + nullifier: hex::encode([2u8; PROTOCOL_FIELD_BYTES]), + value: 13_000_000, + position: 0, + diversifier: hex::encode([0u8; ORCHARD_DIVERSIFIER_BYTES]), + rho: hex::encode([0u8; PROTOCOL_FIELD_BYTES]), + rseed: hex::encode([0u8; PROTOCOL_FIELD_BYTES]), + scope: 2, + ufvk_str: String::new(), + }; + + assert!(NoteInfo::try_from(note).is_err()); + } +} diff --git a/backend-lib/src/main/rust/voting/notes.rs b/backend-lib/src/main/rust/voting/notes.rs new file mode 100644 index 000000000..ac8e1665f --- /dev/null +++ b/backend-lib/src/main/rust/voting/notes.rs @@ -0,0 +1,87 @@ +use super::db::*; +use super::helpers::*; +use super::json::*; +use super::*; + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_computeBundleSetupNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + notes_json: JString<'local>, +) -> jobject { + let res = catch_unwind(&mut env, |env| { + let json_notes: Vec = json_from_jstring(env, ¬es_json, "notesJson")?; + let notes: Vec = json_notes + .into_iter() + .map(NoteInfo::try_from) + .collect::>()?; + let (count, weight, bundle_weights) = bundle_setup_from_notes(¬es)?; + make_jni_bundle_setup_result(env, count, weight, &bundle_weights) + }); + unwrap_exc_or(&mut env, res, JObject::null().into_raw()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_setupBundlesNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, + notes_json: JString<'local>, +) -> jobject { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; + let json_notes: Vec = json_from_jstring(env, ¬es_json, "notesJson")?; + let notes: Vec = json_notes + .into_iter() + .map(NoteInfo::try_from) + .collect::>()?; + let (expected_count, expected_weight, bundle_weights) = bundle_setup_from_notes(¬es)?; + let round_id = java_string_to_rust(env, &round_id)?; + let (count, weight) = db + .setup_bundles(&round_id, ¬es) + .map_err(|e| anyhow!("setup_bundles: {}", e))?; + if count != expected_count || weight != expected_weight { + // setup_bundles has already persisted the round's bundles. Treat a + // mismatch as an internal bug; callers must clear the round before retrying. + return Err(anyhow!( + "setup_bundles result mismatch after persisting bundles; call clearRound before retrying: db=({}, {}) chunk=({}, {})", + count, + weight, + expected_count, + expected_weight + )); + } + make_jni_bundle_setup_result(env, count, weight, &bundle_weights) + }); + unwrap_exc_or(&mut env, res, JObject::null().into_raw()) +} + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_generateHotkeyNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, + seed: JByteArray<'local>, +) -> jobject { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; + let seed = java_secret_bytes_at_least(env, &seed, "seed", PROTOCOL_FIELD_BYTES)?; + let round_id = java_string_to_rust(env, &round_id)?; + let hotkey = db + .generate_hotkey(&round_id, seed.expose_secret()) + .map_err(|e| anyhow!("generate_hotkey: {}", e))?; + update_round_phase_forward(&db, &round_id, RoundPhase::HotkeyGenerated)?; + make_jni_voting_hotkey(env, hotkey) + }); + unwrap_exc_or(&mut env, res, JObject::null().into_raw()) +} diff --git a/backend-lib/src/main/rust/voting/rounds.rs b/backend-lib/src/main/rust/voting/rounds.rs index 2dce8f909..fdf62bb2c 100644 --- a/backend-lib/src/main/rust/voting/rounds.rs +++ b/backend-lib/src/main/rust/voting/rounds.rs @@ -18,6 +18,7 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_ini ) { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; let params = voting::types::VotingRoundParams { vote_round_id: java_string_to_rust(env, &round_id)?, snapshot_height: jlong_to_u64(snapshot_height, "snapshot_height")?, @@ -49,6 +50,7 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_get ) -> jobject { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; let round_id = java_string_to_rust(env, &round_id)?; if !db .has_round(&round_id) @@ -75,6 +77,7 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_lis ) -> jobjectArray { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; let rounds = db .list_rounds() .map_err(|e| anyhow!("list_rounds: {}", e))?; @@ -83,6 +86,26 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_lis unwrap_exc_or(&mut env, res, std::ptr::null_mut()) } +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_getBundleCountNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, + db_handle: jlong, + round_id: JString<'local>, +) -> jint { + let res = catch_unwind(&mut env, |env| { + let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; + let count = db + .get_bundle_count(&java_string_to_rust(env, &round_id)?) + .map_err(|e| anyhow!("get_bundle_count: {}", e))?; + u32_to_jint(count, "bundle_count") + }); + unwrap_exc_or(&mut env, res, -1) +} + #[unsafe(no_mangle)] pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_getVotesNative< 'local, @@ -94,6 +117,7 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_get ) -> jobjectArray { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; let votes = db .get_votes(&java_string_to_rust(env, &round_id)?) .map_err(|e| anyhow!("get_votes: {}", e))?; @@ -113,6 +137,7 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_cle ) { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; db.clear_round(&java_string_to_rust(env, &round_id)?) .map_err(|e| anyhow!("clear_round: {}", e))?; Ok(()) @@ -132,13 +157,14 @@ pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_del ) -> jlong { let res = catch_unwind(&mut env, |env| { let db = db_from_handle(db_handle)?; + let _access_lock = db.access_lock()?; let deleted_rows = db .delete_skipped_bundles( &java_string_to_rust(env, &round_id)?, jint_to_u32(keep_count, "keep_count")?, ) .map_err(|e| anyhow!("delete_skipped_bundles: {}", e))?; - Ok(deleted_rows as jlong) + u64_to_jlong(deleted_rows, "deleted_rows") }); unwrap_exc_or(&mut env, res, -1) } diff --git a/backend-lib/src/main/rust/voting/util.rs b/backend-lib/src/main/rust/voting/util.rs new file mode 100644 index 000000000..826ac08e4 --- /dev/null +++ b/backend-lib/src/main/rust/voting/util.rs @@ -0,0 +1,15 @@ +use super::*; + +#[unsafe(no_mangle)] +pub extern "C" fn Java_cash_z_ecc_android_sdk_internal_jni_VotingRustBackend_warmProvingCachesNative< + 'local, +>( + mut env: JNIEnv<'local>, + _: JClass<'local>, +) { + let res = catch_unwind(&mut env, |_env| { + voting::warm_proving_caches(); + Ok(()) + }); + unwrap_exc_or(&mut env, res, ()) +} diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/ext/BlockExt.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/ext/BlockExt.kt index fd093d701..5882bb740 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/ext/BlockExt.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/ext/BlockExt.kt @@ -2,24 +2,35 @@ package cash.z.ecc.android.sdk.ext import java.util.Locale +private const val HEX_CHARS_PER_BYTE = 2 +private const val HEX_RADIX = 16 + fun ByteArray.toHex(): String { - val sb = StringBuilder(size * 2) + val sb = StringBuilder(size * HEX_CHARS_PER_BYTE) for (b in this) { sb.append(String.format(Locale.ROOT, "%02x", b)) } return sb.toString() } -// Not used within the SDK, but is used by the Wallet app -@Suppress("unused", "MagicNumber") +@Suppress("MagicNumber") fun String.fromHex(): ByteArray { + require(length % HEX_CHARS_PER_BYTE == 0) { + "Hex string must have an even length, got $length" + } + val len = length - val data = ByteArray(len / 2) + val data = ByteArray(len / HEX_CHARS_PER_BYTE) var i = 0 while (i < len) { + val high = Character.digit(this[i], HEX_RADIX) + val low = Character.digit(this[i + 1], HEX_RADIX) + require(high >= 0 && low >= 0) { + "Invalid hex character at index $i" + } data[i / 2] = - ((Character.digit(this[i], 16) shl 4) + Character.digit(this[i + 1], 16)).toByte() - i += 2 + ((high shl 4) + low).toByte() + i += HEX_CHARS_PER_BYTE } return data } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt index e983b4aad..3d272ca76 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackend.kt @@ -1,16 +1,35 @@ package cash.z.ecc.android.sdk.internal +import cash.z.ecc.android.sdk.internal.model.voting.JniBundleSetupResult import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord +import cash.z.ecc.android.sdk.internal.model.voting.JniVotingHotkey @Suppress("TooManyFunctions", "LongParameterList") -interface TypesafeVotingBackend { +internal interface TypesafeVotingBackend { suspend fun openVotingDb(dbPath: String, walletId: String): TypesafeVotingDb + + suspend fun computeShareNullifier( + voteCommitment: ByteArray, + shareIndex: Int, + blind: ByteArray + ): ByteArray + + suspend fun computeBundleSetup(notesJson: String): JniBundleSetupResult + + suspend fun warmProvingCaches() + + suspend fun extractPcztSighash(pcztBytes: ByteArray): ByteArray + + suspend fun extractSpendAuthSig( + signedPcztBytes: ByteArray, + actionIndex: Int + ): ByteArray } @Suppress("TooManyFunctions", "LongParameterList") -interface TypesafeVotingDb { +internal interface TypesafeVotingDb { suspend fun close() suspend fun initRound( @@ -26,6 +45,8 @@ interface TypesafeVotingDb { suspend fun listRounds(): List + suspend fun getBundleCount(roundId: String): Int + suspend fun getVotes(roundId: String): List suspend fun clearRound(roundId: String) @@ -34,4 +55,51 @@ interface TypesafeVotingDb { roundId: String, keepCount: Int ): Long + + suspend fun setupBundles( + roundId: String, + notesJson: String + ): JniBundleSetupResult + + suspend fun generateHotkey( + roundId: String, + seed: ByteArray + ): JniVotingHotkey + + suspend fun buildGovernancePczt( + roundId: String, + bundleIndex: Int, + ufvk: String, + networkId: Int, + accountIndex: Int, + notesJson: String, + walletSeed: ByteArray, + seedFingerprint: ByteArray, + roundName: String, + addressIndex: Int + ): GovernancePcztResult +} + +internal data class GovernancePcztResult( + val pcztBytes: ByteArray, + val rk: ByteArray, + val sighash: ByteArray, + val actionIndex: Int +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is GovernancePcztResult) return false + return pcztBytes.contentEquals(other.pcztBytes) && + rk.contentEquals(other.rk) && + sighash.contentEquals(other.sighash) && + actionIndex == other.actionIndex + } + + override fun hashCode(): Int { + var result = pcztBytes.contentHashCode() + result = 31 * result + rk.contentHashCode() + result = 31 * result + sighash.contentHashCode() + result = 31 * result + actionIndex + return result + } } diff --git a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt index 3985ccb38..4c3e7a03c 100644 --- a/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt +++ b/sdk-lib/src/main/java/cash/z/ecc/android/sdk/internal/TypesafeVotingBackendImpl.kt @@ -1,20 +1,48 @@ package cash.z.ecc.android.sdk.internal +import cash.z.ecc.android.sdk.ext.fromHex import cash.z.ecc.android.sdk.internal.jni.VotingRustBackend +import cash.z.ecc.android.sdk.internal.model.voting.JniBundleSetupResult import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord +import cash.z.ecc.android.sdk.internal.model.voting.JniVotingHotkey +import org.json.JSONObject + +private const val PCZT_HASH_BYTES = 32 @Suppress("TooManyFunctions", "LongParameterList") -class TypesafeVotingBackendImpl : TypesafeVotingBackend { +internal class TypesafeVotingBackendImpl : TypesafeVotingBackend { private val rustBackendLazy = SuspendingLazy { VotingRustBackend.new() } + override suspend fun computeShareNullifier( + voteCommitment: ByteArray, + shareIndex: Int, + blind: ByteArray + ): ByteArray = + rustBackend().computeShareNullifier(voteCommitment, shareIndex, blind) + override suspend fun openVotingDb(dbPath: String, walletId: String): TypesafeVotingDb = TypesafeVotingDbImpl(rustBackend().openVotingDb(dbPath, walletId)) + override suspend fun computeBundleSetup(notesJson: String): JniBundleSetupResult = + rustBackend().computeBundleSetup(notesJson) + + override suspend fun warmProvingCaches() = + rustBackend().warmProvingCaches() + + override suspend fun extractPcztSighash(pcztBytes: ByteArray): ByteArray = + rustBackend().extractPcztSighash(pcztBytes) + + override suspend fun extractSpendAuthSig( + signedPcztBytes: ByteArray, + actionIndex: Int + ): ByteArray = + rustBackend().extractSpendAuthSig(signedPcztBytes, actionIndex) + private suspend fun rustBackend() = rustBackendLazy.getInstance(Unit) } @@ -46,6 +74,9 @@ private class TypesafeVotingDbImpl( override suspend fun listRounds(): List = votingDb.listRounds().asList() + override suspend fun getBundleCount(roundId: String): Int = + votingDb.getBundleCount(roundId) + override suspend fun getVotes(roundId: String): List = votingDb.getVotes(roundId).asList() @@ -56,4 +87,67 @@ private class TypesafeVotingDbImpl( roundId: String, keepCount: Int ): Long = votingDb.deleteSkippedBundles(roundId, keepCount) + + override suspend fun setupBundles( + roundId: String, + notesJson: String + ): JniBundleSetupResult = + votingDb.setupBundles(roundId, notesJson) + + override suspend fun generateHotkey( + roundId: String, + seed: ByteArray + ): JniVotingHotkey = + votingDb.generateHotkey(roundId, seed) + + override suspend fun buildGovernancePczt( + roundId: String, + bundleIndex: Int, + ufvk: String, + networkId: Int, + accountIndex: Int, + notesJson: String, + walletSeed: ByteArray, + seedFingerprint: ByteArray, + roundName: String, + addressIndex: Int + ): GovernancePcztResult = + JSONObject( + votingDb.buildGovernancePcztJson( + roundId, + bundleIndex, + ufvk, + networkId, + accountIndex, + notesJson, + walletSeed, + seedFingerprint, + roundName, + addressIndex + ) + ).toGovernancePcztResult() +} + +private fun JSONObject.getCheckedInt(name: String): Int = + Math.toIntExact(getLong(name)) + +private fun JSONObject.toGovernancePcztResult() = + GovernancePcztResult( + pcztBytes = getHexBytes("pczt_bytes"), + rk = getHexBytes("rk", PCZT_HASH_BYTES), + sighash = getHexBytes("pczt_sighash", PCZT_HASH_BYTES), + actionIndex = getCheckedInt("action_index") + ) + +private fun JSONObject.getHexBytes( + name: String, + expectedSize: Int? = null +): ByteArray { + val bytes = getString(name).fromHex() + + require(expectedSize == null || bytes.size == expectedSize) { + "$name must be $expectedSize bytes, got ${bytes.size}" + } + + return bytes }