Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend-lib/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions backend-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ rust-analyzer = "0.0.1"
# The Android JNI boundary for voting code is `src/main/rust/voting.rs`.
# Its share-nullifier entry point forwards to
# `zcash_voting::share_tracking::compute_share_nullifier`:
# https://github.com/valargroup/zcash_voting/blob/zcash_voting-v0.5.3/zcash_voting/src/share_tracking.rs
# https://github.com/valargroup/zcash_voting/blob/zcash_voting-v0.5.9/zcash_voting/src/share_tracking.rs
#
# The transitive `voting-circuits` dependency contains the share-reveal circuit
# and the Poseidon-based `share_nullifier_hash` implementation:
# https://github.com/valargroup/voting-circuits/blob/v0.4.1/voting-circuits/src/share_reveal/circuit.rs
# Default features stay disabled so this foundation change does not pull in the
# optional client PIR, tree-sync, or networking stacks.
zcash_voting = { version = "0.5.3", default-features = false }
zcash_voting = { version = "0.5.9", default-features = false }

## Uncomment this to test librustzcash changes locally
#[patch.crates-io]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package cash.z.ecc.android.sdk.internal.jni

import cash.z.ecc.android.sdk.internal.model.voting.JniRoundPhase
import kotlinx.coroutines.test.runTest
import org.junit.Test
import kotlin.io.path.createTempDirectory
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull

@OptIn(ExperimentalStdlibApi::class)
@Suppress("MagicNumber")
class VotingRustBackendTest {
companion object {
private const val FIELD_BYTES = 32
Expand All @@ -17,6 +23,15 @@ class VotingRustBackendTest {
private val SHORT_FIELD = ByteArray(FIELD_BYTES - 1)
private val EXPECTED_NULLIFIER =
"8d6d97caa19a20e5e67e7cc24aaaa7beb72b4a513863f6adbe7b62ba1b1b0010".hexToByteArray()

private const val WALLET_ID = "wallet-1"
private const val OTHER_WALLET_ID = "wallet-2"
private const val ROUND_ID = "round-1"
private const val SNAPSHOT_HEIGHT = 123_456L
private const val SESSION_JSON = "{\"round\":\"one\"}"
private val EA_PK = ByteArray(FIELD_BYTES) { 3 }
private val NC_ROOT = ByteArray(FIELD_BYTES) { 4 }
private val NULLIFIER_IMT_ROOT = ByteArray(FIELD_BYTES) { 5 }
}

@Test
Expand Down Expand Up @@ -45,4 +60,106 @@ class VotingRustBackendTest {
backend.computeShareNullifier(VOTE_COMMITMENT, OUT_OF_RANGE_SHARE_INDEX, BLIND)
}
}

@Test
fun voting_db_round_state_round_trips() =
runTest {
val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID)
try {
assertNull(db.getRoundState(ROUND_ID))

db.initRound(
roundId = ROUND_ID,
snapshotHeight = SNAPSHOT_HEIGHT,
eaPK = EA_PK,
ncRoot = NC_ROOT,
nullifierIMTRoot = NULLIFIER_IMT_ROOT,
sessionJson = SESSION_JSON
)

val state = assertNotNull(db.getRoundState(ROUND_ID))
assertEquals(ROUND_ID, state.roundId)
assertEquals(JniRoundPhase.INITIALIZED.value, state.phase)
assertEquals(JniRoundPhase.INITIALIZED, state.roundPhase)
assertEquals(SNAPSHOT_HEIGHT, state.snapshotHeight)
assertNull(state.hotkeyAddress)
assertNull(state.delegatedWeight)
assertFalse(state.proofGenerated)

val rounds = db.listRounds()
assertEquals(1, rounds.size)
val round = rounds.single()
assertEquals(ROUND_ID, round.roundId)
assertEquals(JniRoundPhase.INITIALIZED.value, round.phase)
assertEquals(JniRoundPhase.INITIALIZED, round.roundPhase)
assertEquals(SNAPSHOT_HEIGHT, round.snapshotHeight)

assertEquals(emptyList(), db.getVotes(ROUND_ID).asList())

db.clearRound(ROUND_ID)
assertNull(db.getRoundState(ROUND_ID))
} finally {
db.close()
}
}

@Test
fun voting_db_keeps_wallet_state_isolated() =
runTest {
val dbPath = newDbPath()
val firstWallet = VotingRustBackend.new().openVotingDb(dbPath, WALLET_ID)
val secondWallet = VotingRustBackend.new().openVotingDb(dbPath, OTHER_WALLET_ID)
try {
firstWallet.initRound(
roundId = ROUND_ID,
snapshotHeight = SNAPSHOT_HEIGHT,
eaPK = EA_PK,
ncRoot = NC_ROOT,
nullifierIMTRoot = NULLIFIER_IMT_ROOT,
sessionJson = null
)

assertNotNull(firstWallet.getRoundState(ROUND_ID))
assertNull(secondWallet.getRoundState(ROUND_ID))
} finally {
firstWallet.close()
secondWallet.close()
}
}

@Test
fun voting_db_rejects_malformed_inputs_and_closed_handle() =
runTest {
val db = VotingRustBackend.new().openVotingDb(newDbPath(), WALLET_ID)

assertFailsWith<RuntimeException> {
db.initRound(
roundId = ROUND_ID,
snapshotHeight = -1,
eaPK = EA_PK,
ncRoot = NC_ROOT,
nullifierIMTRoot = NULLIFIER_IMT_ROOT,
sessionJson = null
)
}
assertFailsWith<RuntimeException> {
db.initRound(
roundId = ROUND_ID,
snapshotHeight = SNAPSHOT_HEIGHT,
eaPK = SHORT_FIELD,
ncRoot = NC_ROOT,
nullifierIMTRoot = NULLIFIER_IMT_ROOT,
sessionJson = null
)
}

db.close()
db.close()
assertFailsWith<IllegalStateException> {
db.getRoundState(ROUND_ID)
}
}

private fun newDbPath() =
createTempDirectory("voting-db-").resolve("voting.db").toFile().absolutePath
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
package cash.z.ecc.android.sdk.internal.jni

import androidx.annotation.Keep
import cash.z.ecc.android.sdk.internal.model.voting.JniRoundState
import cash.z.ecc.android.sdk.internal.model.voting.JniRoundSummary
import cash.z.ecc.android.sdk.internal.model.voting.JniVoteRecord
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext

@Keep
@Suppress("TooManyFunctions", "LongParameterList")
class VotingRustBackend private constructor() {
@Throws(RuntimeException::class)
fun computeShareNullifier(
Expand All @@ -8,6 +19,88 @@ class VotingRustBackend private constructor() {
blind: ByteArray
): ByteArray = computeShareNullifierNative(voteCommitment, shareIndex, blind)

suspend fun openVotingDb(dbPath: String, walletId: String): VotingDb =
withContext(Dispatchers.IO) {
openVotingDbNative(dbPath, walletId).let { dbHandle ->
check(dbHandle != 0L) {
"openVotingDb failed for dbPath=$dbPath"
}
VotingDb(dbHandle)
}
}

@Suppress("TooManyFunctions", "LongParameterList")
class VotingDb internal constructor(
private var dbHandle: Long?
) {
private val accessMutex = Mutex()

suspend fun close() {
accessMutex.withLock {
dbHandle?.let { handle ->
withContext(Dispatchers.IO) {
closeVotingDbNative(handle)
}
dbHandle = null
}
}
}

@Throws(RuntimeException::class)
suspend fun initRound(
roundId: String,
snapshotHeight: Long,
eaPK: ByteArray,
ncRoot: ByteArray,
nullifierIMTRoot: ByteArray,
sessionJson: String?
) = withHandle { handle ->
initRoundNative(
handle,
roundId,
snapshotHeight,
eaPK,
ncRoot,
nullifierIMTRoot,
sessionJson
)
}

@Throws(RuntimeException::class)
suspend fun getRoundState(roundId: String): JniRoundState? =
withHandle { handle -> getRoundStateNative(handle, roundId) }

@Throws(RuntimeException::class)
suspend fun listRounds(): Array<JniRoundSummary> =
withHandle { handle -> listRoundsNative(handle) }

@Throws(RuntimeException::class)
suspend fun getVotes(roundId: String): Array<JniVoteRecord> =
withHandle { handle -> getVotesNative(handle, roundId) }

@Throws(RuntimeException::class)
suspend fun clearRound(roundId: String) =
withHandle { handle -> clearRoundNative(handle, roundId) }

@Throws(RuntimeException::class)
suspend fun deleteSkippedBundles(
roundId: String,
keepCount: Int
): Long =
withHandle { handle -> deleteSkippedBundlesNative(handle, roundId, keepCount) }

private suspend fun <T> withHandle(block: (Long) -> T): T =
accessMutex.withLock {
val handle =
checkNotNull(dbHandle) {
"Voting DB handle is closed"
}
withContext(Dispatchers.IO) {
block(handle)
}
}
}

companion object {
suspend fun new(): VotingRustBackend {
RustBackend.loadLibrary()
Expand All @@ -22,5 +115,49 @@ class VotingRustBackend private constructor() {
shareIndex: Int,
blind: ByteArray
): ByteArray

@JvmStatic
@Throws(RuntimeException::class)
private external fun openVotingDbNative(dbPath: String, walletId: String): Long

@JvmStatic
@Throws(RuntimeException::class)
private external fun closeVotingDbNative(dbHandle: Long)

@JvmStatic
@Throws(RuntimeException::class)
private external fun initRoundNative(
dbHandle: Long,
roundId: String,
snapshotHeight: Long,
eaPK: ByteArray,
ncRoot: ByteArray,
nullifierIMTRoot: ByteArray,
sessionJson: String?
)

@JvmStatic
@Throws(RuntimeException::class)
private external fun getRoundStateNative(dbHandle: Long, roundId: String): JniRoundState?

@JvmStatic
@Throws(RuntimeException::class)
private external fun listRoundsNative(dbHandle: Long): Array<JniRoundSummary>

@JvmStatic
@Throws(RuntimeException::class)
private external fun getVotesNative(dbHandle: Long, roundId: String): Array<JniVoteRecord>

@JvmStatic
@Throws(RuntimeException::class)
private external fun clearRoundNative(dbHandle: Long, roundId: String)

@JvmStatic
@Throws(RuntimeException::class)
private external fun deleteSkippedBundlesNative(
dbHandle: Long,
roundId: String,
keepCount: Int
): Long
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package cash.z.ecc.android.sdk.internal.model.voting

import androidx.annotation.Keep

// Must match PHASE_* constants in backend-lib/src/main/rust/voting/helpers.rs.
internal const val JNI_ROUND_PHASE_INITIALIZED = 0
internal const val JNI_ROUND_PHASE_HOTKEY_GENERATED = 1
internal const val JNI_ROUND_PHASE_DELEGATION_CONSTRUCTED = 2
internal const val JNI_ROUND_PHASE_DELEGATION_PROVED = 3
internal const val JNI_ROUND_PHASE_VOTE_READY = 4

@Keep
data class JniRoundState(
val roundId: String,
val phase: Int,
val snapshotHeight: Long,
val hotkeyAddress: String?,
val delegatedWeight: Long?,
val proofGenerated: Boolean
) {
val roundPhase = JniRoundPhase.fromInt(phase)
}

@Keep
enum class JniRoundPhase(
val value: Int
) {
INITIALIZED(JNI_ROUND_PHASE_INITIALIZED),
HOTKEY_GENERATED(JNI_ROUND_PHASE_HOTKEY_GENERATED),
DELEGATION_CONSTRUCTED(JNI_ROUND_PHASE_DELEGATION_CONSTRUCTED),
DELEGATION_PROVED(JNI_ROUND_PHASE_DELEGATION_PROVED),
VOTE_READY(JNI_ROUND_PHASE_VOTE_READY);

companion object {
fun fromInt(value: Int) =

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forward compatibility risk

This throws if Rust adds a new RoundPhase variant before Kotlin is updated — a library update alone can crash the app. For an internal sdk-lib boundary it is defensible, but consider an UNKNOWN(-1) fallback or returning null until the phase set is stable.

@greg0x greg0x May 11, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an internal api contract between rust and kotlin. For this it's better to fail as soon as possible so the issue comes out during dev, failing silently.

The phase set can also be considered stable.

p.s.: I wish we could enjoy the comfort of UniFFI which solves this problem at generation time.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p.s.: I wish we could enjoy the comfort of UniFFI which solves this problem at generation time.

@noop-sk Is this something we should track as a future improvement?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree

entries.firstOrNull { it.value == value }
?: error("Unknown round phase: $value")
}
}

@Keep
data class JniRoundSummary(
val roundId: String,
val phase: Int,
val snapshotHeight: Long,
val createdAt: Long
) {
val roundPhase = JniRoundPhase.fromInt(phase)
}

@Keep
data class JniVoteRecord(
val proposalId: Int,
val bundleIndex: Int,
val choice: Int,
val submitted: Boolean
)
Loading