diff --git a/CHANGELOG.md b/CHANGELOG.md index 096f5ff01..1054b1480 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ and this library adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +## Added +- PIR (Private Information Retrieval) integration for detecting Orchard note spendability without waiting for shard-tree scanning to complete. + - `SpendabilityBackend` / `SpendabilityTypes` — Swift wrappers for the nullifier and witness PIR FFI layer. + - `Synchronizer.checkWalletSpendability` — queries a PIR server to determine which notes have been spent, without revealing the wallet's notes to the server. + - `Synchronizer.fetchNoteWitnesses` — fetches Orchard note commitment witnesses from a PIR server, making notes spendable before the scanner catches up. + - `SDKFlags.pirCompleted` — lifecycle flag preserving spendable balance across sync restarts. +- `Proposal.PIRWitnessConfig` — attach to a `Proposal` via `proposal.pirWitnessConfig` to enable PIR witness fetching when the wallet is not fully synced. The SDK handles alignment and retry logic automatically. +- `createProposedTransactions` reads PIR configuration from the proposal itself, keeping the method signature clean. + # 2.4.9 - 2026-04-04 ## Checkpoints diff --git a/Package.resolved b/Package.resolved index cfd7ec42e..a6c0534d3 100644 --- a/Package.resolved +++ b/Package.resolved @@ -27,24 +27,6 @@ "version" : "1.2.1" } }, - { - "identity" : "swift-asn1", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-asn1.git", - "state" : { - "revision" : "810496cf121e525d660cd0ea89a758740476b85f", - "version" : "1.5.1" - } - }, - { - "identity" : "swift-async-algorithms", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-async-algorithms.git", - "state" : { - "revision" : "9d349bcc328ac3c31ce40e746b5882742a0d1272", - "version" : "1.1.3" - } - }, { "identity" : "swift-atomics", "kind" : "remoteSourceControl", @@ -54,15 +36,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-certificates", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-certificates.git", - "state" : { - "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", - "version" : "1.18.0" - } - }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", @@ -72,15 +45,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-crypto", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-crypto.git", - "state" : { - "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", - "version" : "4.2.0" - } - }, { "identity" : "swift-http-structured-headers", "kind" : "remoteSourceControl", @@ -171,15 +135,6 @@ "version" : "1.35.1" } }, - { - "identity" : "swift-service-lifecycle", - "kind" : "remoteSourceControl", - "location" : "https://github.com/swift-server/swift-service-lifecycle.git", - "state" : { - "revision" : "89888196dd79c61c50bca9a103d8114f32e1e598", - "version" : "2.10.1" - } - }, { "identity" : "swift-system", "kind" : "remoteSourceControl", diff --git a/Sources/ZcashLightClientKit/ClosureSynchronizer.swift b/Sources/ZcashLightClientKit/ClosureSynchronizer.swift index 75df4e05c..a6ceb7b9d 100644 --- a/Sources/ZcashLightClientKit/ClosureSynchronizer.swift +++ b/Sources/ZcashLightClientKit/ClosureSynchronizer.swift @@ -89,6 +89,13 @@ public protocol ClosureSynchronizer { /// proposal, indicating whether they were submitted to the network or if an error /// occurred. /// + /// - Parameters: + /// - proposal: The proposal for which to create transactions. Attach a + /// `Proposal.PIRWitnessConfig` via `proposal.pirWitnessConfig` to enable PIR witness + /// fetching when the wallet is not fully synced. + /// - spendingKey: The `UnifiedSpendingKey` for the account that controls the funds. + /// - completion: Completion handler. + /// /// If `prepare()` hasn't already been called since creation of the synchronizer instance /// or since the last wipe then this method throws `SynchronizerErrors.notPrepared`. func createProposedTransactions( diff --git a/Sources/ZcashLightClientKit/CombineSynchronizer.swift b/Sources/ZcashLightClientKit/CombineSynchronizer.swift index ca41eb8f5..fae022e86 100644 --- a/Sources/ZcashLightClientKit/CombineSynchronizer.swift +++ b/Sources/ZcashLightClientKit/CombineSynchronizer.swift @@ -87,6 +87,12 @@ public protocol CombineSynchronizer { /// /// If `prepare()` hasn't already been called since creation of the synchronizer instance /// or since the last wipe then this method throws `SynchronizerErrors.notPrepared`. + /// + /// - Parameters: + /// - proposal: The proposal for which to create transactions. Attach a + /// `Proposal.PIRWitnessConfig` via `proposal.pirWitnessConfig` to enable PIR witness + /// fetching when the wallet is not fully synced. + /// - spendingKey: The `UnifiedSpendingKey` for the account that controls the funds. func createProposedTransactions( proposal: Proposal, spendingKey: UnifiedSpendingKey diff --git a/Sources/ZcashLightClientKit/Model/Proposal.swift b/Sources/ZcashLightClientKit/Model/Proposal.swift index 38cccd380..47561c0b4 100644 --- a/Sources/ZcashLightClientKit/Model/Proposal.swift +++ b/Sources/ZcashLightClientKit/Model/Proposal.swift @@ -9,8 +9,29 @@ import Foundation /// A data structure that describes a series of transactions to be created. public struct Proposal: Equatable { + /// PIR witness configuration attached to a proposal. + /// + /// Set `serverURL` before calling `createProposedTransactions` so the SDK + /// can fetch Orchard witnesses from the PIR server when the wallet is not + /// fully synced. The SDK sets `usePIRWitnesses` internally based on sync + /// status and retry logic. + public struct PIRWitnessConfig: Equatable, Sendable { + public let serverURL: String + public internal(set) var usePIRWitnesses: Bool + + public init(serverURL: String) { + self.serverURL = serverURL + self.usePIRWitnesses = false + } + } + let inner: FfiProposal + /// Optional PIR witness configuration. When set, the SDK will use the + /// provided server URL to fetch Orchard witnesses if the wallet is not + /// fully synced, enabling spending before scanning completes. + public var pirWitnessConfig: PIRWitnessConfig? + /// Returns the number of transactions that this proposal will create. /// /// This is equal to the number of `TransactionSubmitResult`s that will be returned diff --git a/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityBackend.swift b/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityBackend.swift new file mode 100644 index 000000000..1c409fa5e --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityBackend.swift @@ -0,0 +1,132 @@ +// Swift wrapper for the PIR C FFI (spendability.rs + witness.rs). +// Stateless — each call connects to the PIR server and returns. + +import Foundation +import libzcashlc + +// MARK: - Error + +public enum SpendabilityBackendError: LocalizedError, Equatable { + case rustError(String) + + public var errorDescription: String? { + switch self { + case .rustError(let message): + return "Spendability backend error: \(message)" + } + } +} + +// MARK: - SpendabilityBackend + +/// Wraps the PIR network FFI. Stateless — no DB handle, no persistent connection. +public struct SpendabilityBackend: Sendable { + public init() {} + + /// Check nullifiers against the PIR server. No database access. + /// + /// - Parameters: + /// - notes: Unspent notes with nullifiers (from phase 1 DB read). + /// - pirServerUrl: Base URL of the spend-server. + /// - progress: Optional progress callback (0.0..1.0). + /// - Returns: A `PIRNullifierCheckResult` with spent flags and server metadata. + public func checkNullifiersPIR( + notes: [PIRUnspentNote], + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) throws -> PIRNullifierCheckResult { + let urlBytes = [UInt8](pirServerUrl.utf8) + + let nullifiers: [[UInt8]] = notes.map { $0.nf } + let nullifiersJSON = try JSONEncoder().encode(nullifiers) + + var context = SpendabilityProgressContext(handler: progress) + + let ptr: UnsafeMutablePointer? = urlBytes.withUnsafeBufferPointer { urlBuf in + nullifiersJSON.withUnsafeBytes { nfBuf in + withUnsafeMutablePointer(to: &context) { ctxPtr in + let callback: (@convention(c) (Double, UnsafeMutableRawPointer?) -> Void)? = + progress != nil ? spendabilityProgressTrampoline : nil + return zcashlc_check_nullifiers_pir( + urlBuf.baseAddress, + UInt(urlBuf.count), + nfBuf.baseAddress?.assumingMemoryBound(to: UInt8.self), + UInt(nfBuf.count), + callback, + UnsafeMutableRawPointer(ctxPtr) + ) + } + } + } + + guard let ptr else { + throw SpendabilityBackendError.rustError(lastErrorMessage(fallback: "`checkNullifiersPIR` failed")) + } + defer { zcashlc_free_boxed_slice(ptr) } + + let data = Data(bytes: ptr.pointee.ptr, count: Int(ptr.pointee.len)) + return try JSONDecoder().decode(PIRNullifierCheckResult.self, from: data) + } + + /// Fetch note commitment witnesses from the PIR server. No database access. + /// + /// - Parameters: + /// - notes: Notes needing witnesses (from DB read). + /// - pirServerUrl: Base URL of the witness PIR server. + /// - progress: Optional progress callback (0.0..1.0). + /// - Returns: A `PIRWitnessResult` with witness data for each note. + public func fetchWitnesses( + notes: [PIRNotePosition], + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) throws -> PIRWitnessResult { + let urlBytes = [UInt8](pirServerUrl.utf8) + + struct PositionInput: Codable { + let note_id: Int64 + let position: UInt64 + } + + let positions = notes.map { PositionInput(note_id: $0.id, position: $0.position) } + let positionsJSON = try JSONEncoder().encode(positions) + + var context = SpendabilityProgressContext(handler: progress) + + let ptr: UnsafeMutablePointer? = urlBytes.withUnsafeBufferPointer { urlBuf in + positionsJSON.withUnsafeBytes { posBuf in + withUnsafeMutablePointer(to: &context) { ctxPtr in + let callback: (@convention(c) (Double, UnsafeMutableRawPointer?) -> Void)? = + progress != nil ? spendabilityProgressTrampoline : nil + return zcashlc_fetch_pir_witnesses( + urlBuf.baseAddress, + UInt(urlBuf.count), + posBuf.baseAddress?.assumingMemoryBound(to: UInt8.self), + UInt(posBuf.count), + callback, + UnsafeMutableRawPointer(ctxPtr) + ) + } + } + } + + guard let ptr else { + throw SpendabilityBackendError.rustError(lastErrorMessage(fallback: "`fetchWitnesses` failed")) + } + defer { zcashlc_free_boxed_slice(ptr) } + + let data = Data(bytes: ptr.pointee.ptr, count: Int(ptr.pointee.len)) + return try JSONDecoder().decode(PIRWitnessResult.self, from: data) + } +} + +// MARK: - Progress callback trampoline + +private struct SpendabilityProgressContext { + let handler: SpendabilityProgressHandler? +} + +private func spendabilityProgressTrampoline(progress: Double, context: UnsafeMutableRawPointer?) { + guard let context else { return } + let ctx = context.assumingMemoryBound(to: SpendabilityProgressContext.self).pointee + ctx.handler?(progress) +} diff --git a/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityTypes.swift b/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityTypes.swift new file mode 100644 index 000000000..62d71f313 --- /dev/null +++ b/Sources/ZcashLightClientKit/Rust/Spendability/SpendabilityTypes.swift @@ -0,0 +1,179 @@ +// Swift types matching the JSON serde types in spendability.rs and witness.rs. +// All types are Codable for JSON serialization across the FFI boundary. + +import Foundation + +// MARK: - Result + +/// Result of a spendability PIR check. +public struct SpendabilityResult: Codable, Sendable, Equatable { + /// Earliest block height covered by the PIR database. + public let earliestHeight: UInt64 + /// Latest block height covered by the PIR database. + public let latestHeight: UInt64 + /// Note IDs whose nullifiers were found in the PIR database (i.e. spent). + public let spentNoteIds: [Int64] + /// Total zatoshi value of notes found spent by PIR. + public let totalSpentValue: UInt64 + + enum CodingKeys: String, CodingKey { + case earliestHeight = "earliest_height" + case latestHeight = "latest_height" + case spentNoteIds = "spent_note_ids" + case totalSpentValue = "total_spent_value" + } + + /// Whether any notes were detected as spent by the PIR server. + /// When true, the wallet should skip witness PIR and fall back to standard scanning. + public var anySpent: Bool { !spentNoteIds.isEmpty } + + public init(earliestHeight: UInt64, latestHeight: UInt64, spentNoteIds: [Int64], totalSpentValue: UInt64) { + self.earliestHeight = earliestHeight + self.latestHeight = latestHeight + self.spentNoteIds = spentNoteIds + self.totalSpentValue = totalSpentValue + } +} + +// MARK: - Unspent note + +/// An unspent Orchard note with its nullifier, for PIR spend-checking. +public struct PIRUnspentNote: Codable, Sendable, Equatable { + public let id: Int64 + /// Raw nullifier bytes (32 bytes). + public let nf: [UInt8] + public let value: UInt64 + + public init(id: Int64, nf: [UInt8], value: UInt64) { + self.id = id + self.nf = nf + self.value = value + } +} + +// MARK: - Spend metadata + +/// Per-nullifier metadata returned by the PIR server when a nullifier is found spent. +public struct PIRSpendMetadata: Codable, Sendable, Equatable { + /// Block height at which the note was spent. + public let spendHeight: UInt32 + /// Global Orchard commitment-tree position of the first output in the spending transaction. + public let firstOutputPosition: UInt32 + /// Number of Orchard actions in the spending transaction. + public let actionCount: UInt8 + + enum CodingKeys: String, CodingKey { + case spendHeight = "spend_height" + case firstOutputPosition = "first_output_position" + case actionCount = "action_count" + } + + public init(spendHeight: UInt32, firstOutputPosition: UInt32, actionCount: UInt8) { + self.spendHeight = spendHeight + self.firstOutputPosition = firstOutputPosition + self.actionCount = actionCount + } +} + +// MARK: - Nullifier check result + +/// Result of checking nullifiers against the PIR server. +public struct PIRNullifierCheckResult: Codable, Sendable, Equatable { + public let earliestHeight: UInt64 + public let latestHeight: UInt64 + /// Parallel to the input nullifiers: non-nil = spent (with metadata), nil = not spent. + public let spent: [PIRSpendMetadata?] + + enum CodingKeys: String, CodingKey { + case earliestHeight = "earliest_height" + case latestHeight = "latest_height" + case spent + } + + public init(earliestHeight: UInt64, latestHeight: UInt64, spent: [PIRSpendMetadata?]) { + self.earliestHeight = earliestHeight + self.latestHeight = latestHeight + self.spent = spent + } +} + +// MARK: - Progress + +/// Closure type for spendability check progress reporting. +public typealias SpendabilityProgressHandler = @Sendable (Double) -> Void + +// MARK: - Note position (input to witness PIR) + +/// An Orchard note that needs a PIR witness: has a tree position but the shard +/// containing it is not fully scanned. +public struct PIRNotePosition: Codable, Sendable, Equatable { + public let id: Int64 + public let position: UInt64 + public let value: UInt64 + + public init(id: Int64, position: UInt64, value: UInt64) { + self.id = id + self.position = position + self.value = value + } +} + +// MARK: - Witness entry (output from PIR server / input to DB write) + +/// A PIR-obtained witness for a single note. Sibling hashes are hex-encoded +/// 32-byte values ordered leaf-to-root. +public struct PIRWitnessEntry: Codable, Sendable, Equatable { + public let noteId: Int64 + public let position: UInt64 + /// 32 sibling hashes, each a 64-char hex string (32 bytes). + public let siblings: [String] + public let anchorHeight: UInt64 + /// The tree root at `anchorHeight`, as a 64-char hex string. + public let anchorRoot: String + + enum CodingKeys: String, CodingKey { + case noteId = "note_id" + case position + case siblings + case anchorHeight = "anchor_height" + case anchorRoot = "anchor_root" + } + + public init( + noteId: Int64, + position: UInt64, + siblings: [String], + anchorHeight: UInt64, + anchorRoot: String + ) { + self.noteId = noteId + self.position = position + self.siblings = siblings + self.anchorHeight = anchorHeight + self.anchorRoot = anchorRoot + } +} + +// MARK: - Witness fetch result (from PIR server) + +/// Result of fetching witnesses from the PIR server. +public struct PIRWitnessResult: Codable, Sendable, Equatable { + public let witnesses: [PIRWitnessEntry] + + public init(witnesses: [PIRWitnessEntry]) { + self.witnesses = witnesses + } +} + +// MARK: - Orchestration result (returned to app layer) + +/// Result of `fetchNoteWitnesses` — notes for which witnesses were obtained. +public struct WitnessResult: Sendable, Equatable { + public let witnessedNoteIds: [Int64] + public let totalWitnessedValue: UInt64 + + public init(witnessedNoteIds: [Int64], totalWitnessedValue: UInt64) { + self.witnessedNoteIds = witnessedNoteIds + self.totalWitnessedValue = totalWitnessedValue + } +} diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift index 4a5036e18..97dd5a31a 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackend.swift @@ -962,19 +962,15 @@ struct ZcashRustBackend: ZcashRustBackendWelding { if await !sdkFlags.chainTipUpdated { accountBalances.forEach { key, _ in if let accountBalance = accountBalances[key] { + let saplingBalance = PoolBalance( + spendableValue: .zero, + changePendingConfirmation: accountBalance.saplingBalance.changePendingConfirmation, + valuePendingSpendability: accountBalance.saplingBalance.valuePendingSpendability + + accountBalance.saplingBalance.spendableValue + ) accountBalances[key] = AccountBalance( - saplingBalance: PoolBalance( - spendableValue: .zero, - changePendingConfirmation: accountBalance.saplingBalance.changePendingConfirmation, - valuePendingSpendability: accountBalance.saplingBalance.valuePendingSpendability - + accountBalance.saplingBalance.spendableValue - ), - orchardBalance: PoolBalance( - spendableValue: .zero, - changePendingConfirmation: accountBalance.orchardBalance.changePendingConfirmation, - valuePendingSpendability: accountBalance.orchardBalance.valuePendingSpendability - + accountBalance.orchardBalance.spendableValue - ), + saplingBalance: saplingBalance, + orchardBalance: accountBalance.orchardBalance, unshielded: .zero, awaitingResolution: accountBalance.unshielded ) @@ -1091,7 +1087,8 @@ struct ZcashRustBackend: ZcashRustBackendWelding { @DBActor func createProposedTransactions( proposal: FfiProposal, - usk: UnifiedSpendingKey + usk: UnifiedSpendingKey, + usePIRWitnesses: Bool ) async throws -> [Data] { let proposalBytes = try proposal.serializedData(partial: false).bytes @@ -1108,7 +1105,8 @@ struct ZcashRustBackend: ZcashRustBackendWelding { spendParamsPath.1, outputParamsPath.0, outputParamsPath.1, - networkType.networkId + networkType.networkId, + usePIRWitnesses ) } } @@ -1285,6 +1283,94 @@ struct ZcashRustBackend: ZcashRustBackendWelding { ) } } + + // MARK: - PIR (serialized through @DBActor) + + @DBActor + func getUnspentOrchardNotesForPIR() async throws -> [PIRUnspentNote] { + let ptr = zcashlc_get_unspent_orchard_notes_for_pir( + dbData.0, + dbData.1, + networkType.networkId + ) + + guard let ptr else { + throw SpendabilityBackendError.rustError( + lastErrorMessage(fallback: "`getUnspentOrchardNotesForPIR` failed") + ) + } + defer { zcashlc_free_boxed_slice(ptr) } + + let data = Data(bytes: ptr.pointee.ptr, count: Int(ptr.pointee.len)) + return try JSONDecoder().decode([PIRUnspentNote].self, from: data) + } + + // MARK: - Witness PIR + + @DBActor + func getNotesNeedingPIRWitness() async throws -> [PIRNotePosition] { + let ptr = zcashlc_get_notes_needing_pir_witness( + dbData.0, + dbData.1, + networkType.networkId + ) + + guard let ptr else { + throw SpendabilityBackendError.rustError( + lastErrorMessage(fallback: "`getNotesNeedingPIRWitness` failed") + ) + } + defer { zcashlc_free_boxed_slice(ptr) } + + let data = Data(bytes: ptr.pointee.ptr, count: Int(ptr.pointee.len)) + return try JSONDecoder().decode([PIRNotePosition].self, from: data) + } + + @DBActor + func getPIRWitnessNotes(for proposal: FfiProposal) async throws -> [PIRNotePosition] { + let proposalBytes = try proposal.serializedData(partial: false).bytes + let ptr = proposalBytes.withUnsafeBufferPointer { proposalPtr in + zcashlc_get_pir_witness_notes_for_proposal( + dbData.0, + dbData.1, + proposalPtr.baseAddress, + UInt(proposalBytes.count), + networkType.networkId + ) + } + + guard let ptr else { + throw SpendabilityBackendError.rustError( + lastErrorMessage(fallback: "`getPIRWitnessNotes(for:)` failed") + ) + } + defer { zcashlc_free_boxed_slice(ptr) } + + let data = Data(bytes: ptr.pointee.ptr, count: Int(ptr.pointee.len)) + return try JSONDecoder().decode([PIRNotePosition].self, from: data) + } + + @DBActor + func insertPIRWitnesses(_ witnesses: [PIRWitnessEntry]) async throws { + let json = try JSONEncoder().encode(witnesses) + + let result = json.withUnsafeBytes { buf in + zcashlc_insert_pir_witnesses( + dbData.0, + dbData.1, + networkType.networkId, + buf.baseAddress?.assumingMemoryBound(to: UInt8.self), + UInt(buf.count) + ) + } + + guard result == 0 else { + throw SpendabilityBackendError.rustError( + lastErrorMessage(fallback: "`insertPIRWitnesses` failed") + ) + } + } + } private extension ZcashRustBackend { diff --git a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift index c4fb0e85b..1e3067c5b 100644 --- a/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift +++ b/Sources/ZcashLightClientKit/Rust/ZcashRustBackendWelding.swift @@ -288,10 +288,13 @@ protocol ZcashRustBackendWelding { /// Creates a transaction from the given proposal. /// - Parameter proposal: the transaction proposal. /// - Parameter usk: `UnifiedSpendingKey` for the account that controls the funds to be spent. + /// - Parameter usePIRWitnesses: When `true`, Orchard witnesses are read from + /// PIR-stored data instead of the local ShardTree. /// - Throws: `rustCreateToAddress`. func createProposedTransactions( proposal: FfiProposal, - usk: UnifiedSpendingKey + usk: UnifiedSpendingKey, + usePIRWitnesses: Bool ) async throws -> [Data] /// Creates a partially-created (unsigned without proofs) transaction from the given proposal. @@ -391,4 +394,22 @@ protocol ZcashRustBackendWelding { /// Attempts to delete an account defined by UUID func deleteAccount(_ accountUUID: AccountUUID) async throws + + // MARK: - PIR (serialized through @DBActor, no standalone connections) + + /// Returns unspent Orchard notes with nullifiers for PIR spend-checking. + func getUnspentOrchardNotesForPIR() async throws -> [PIRUnspentNote] + + // MARK: - Witness PIR (serialized through @DBActor, no standalone connections) + + /// Returns canonical Orchard notes that should be considered for PIR + /// witness fetch or refresh. + func getNotesNeedingPIRWitness() async throws -> [PIRNotePosition] + + /// Returns Orchard notes selected by the provided proposal that may require + /// a PIR witness refresh before transaction construction. + func getPIRWitnessNotes(for proposal: FfiProposal) async throws -> [PIRNotePosition] + + /// Inserts PIR-obtained witnesses into the wallet DB. + func insertPIRWitnesses(_ witnesses: [PIRWitnessEntry]) async throws } diff --git a/Sources/ZcashLightClientKit/Synchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer.swift index b72a85f31..049fef002 100644 --- a/Sources/ZcashLightClientKit/Synchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer.swift @@ -208,13 +208,16 @@ public protocol Synchronizer: AnyObject { /// Creates the transactions in the given proposal. /// - /// - Parameter proposal: the proposal for which to create transactions. - /// - Parameter spendingKey: the `UnifiedSpendingKey` associated with the account for which the proposal was created. - /// /// Returns a stream of objects for the transactions that were created as part of the /// proposal, indicating whether they were submitted to the network or if an error /// occurred. /// + /// - Parameters: + /// - proposal: The proposal for which to create transactions. Attach a + /// `Proposal.PIRWitnessConfig` via `proposal.pirWitnessConfig` to enable PIR witness + /// fetching when the wallet is not fully synced. + /// - spendingKey: The `UnifiedSpendingKey` for the account that controls the funds. + /// /// If `prepare()` hasn't already been called since creation of the synchronizer instance /// or since the last wipe then this method throws `SynchronizerErrors.notPrepared`. func createProposedTransactions( @@ -526,6 +529,34 @@ public protocol Synchronizer: AnyObject { /// /// - Throws rustDeleteAccount as a common indicator of the operation failure func deleteAccount(_ accountUUID: AccountUUID) async throws -> Void + + /// Check spendability of all unspent orchard notes in the wallet using a PIR server. + /// Queries the wallet DB for unspent notes, checks each via PIR, and returns + /// which are spent along with total spent value. + /// + /// - Parameters: + /// - pirServerUrl: Base URL of the spend-server. + /// - progress: Optional progress callback (0.0..1.0). + func checkWalletSpendability( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> SpendabilityResult + + /// Fetch note commitment witnesses from the PIR server for canonical Orchard + /// notes the wallet wants to keep spendable during sync. + /// + /// Orchestrates: DB read (canonical notes) -> PIR server fetch -> DB write + /// (store witnesses). Notes with witnesses bypass the shard-scanned gate in + /// coin selection, making them spendable sooner. + /// + /// - Parameters: + /// - pirServerUrl: Base URL of the witness PIR server. + /// - progress: Optional progress callback (0.0..1.0). + func fetchNoteWitnesses( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> WitnessResult + } public enum SyncStatus: Equatable { diff --git a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift index db0aa0ee9..5f96fe764 100644 --- a/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift +++ b/Sources/ZcashLightClientKit/Synchronizer/SDKSynchronizer.swift @@ -15,6 +15,8 @@ public class SDKSynchronizer: Synchronizer { private enum Constants { static let fixWitnessesLastVersionCall = "ud_fixWitnessesLastVersionCall" } + + typealias PIRWitnessFetcher = @Sendable ([PIRNotePosition], String, SpendabilityProgressHandler?) throws -> PIRWitnessResult public var alias: ZcashSynchronizerAlias { initializer.alias } @@ -49,6 +51,7 @@ public class SDKSynchronizer: Synchronizer { public let network: ZcashNetwork private var transactionEncoder: TransactionEncoder private let transactionRepository: TransactionRepository + private let pirWitnessFetcher: PIRWitnessFetcher private let syncSessionIDGenerator: SyncSessionIDGenerator private let syncSession: SyncSession @@ -67,7 +70,14 @@ public class SDKSynchronizer: Synchronizer { initializer: initializer, walletBirthdayProvider: { initializer.walletBirthday } ), - syncSessionTicker: .live + syncSessionTicker: .live, + pirWitnessFetcher: { notes, pirServerUrl, progress in + try SpendabilityBackend().fetchWitnesses( + notes: notes, + pirServerUrl: pirServerUrl, + progress: progress + ) + } ) } @@ -77,13 +87,15 @@ public class SDKSynchronizer: Synchronizer { transactionEncoder: TransactionEncoder, transactionRepository: TransactionRepository, blockProcessor: CompactBlockProcessor, - syncSessionTicker: SessionTicker + syncSessionTicker: SessionTicker, + pirWitnessFetcher: @escaping PIRWitnessFetcher ) { self.connectionState = .idle self.underlyingStatus = GenericActor(status) self.initializer = initializer self.transactionEncoder = transactionEncoder self.transactionRepository = transactionRepository + self.pirWitnessFetcher = pirWitnessFetcher self.blockProcessor = blockProcessor self.network = initializer.network self.metrics = initializer.container.resolve(SDKMetrics.self) @@ -430,10 +442,33 @@ public class SDKSynchronizer: Synchronizer { logger: logger ) - let transactions = try await transactionEncoder.createProposedTransactions( - proposal: proposal, - spendingKey: spendingKey - ) + // If the wallet is syncing and PIR witnesses are enabled, align the proposal witnesses. + var proposal = proposal + let isSyncing = await status.isSyncing + if isSyncing && proposal.pirWitnessConfig?.usePIRWitnesses == true { + try await alignProposalWitnesses(proposal: proposal) + } + + logger.info("[PIR-DEBUG] createProposedTransactions: status=\(isSyncing), usePIRWitnesses=\(proposal.pirWitnessConfig?.usePIRWitnesses ?? false), pirServerURL=\(proposal.pirWitnessConfig?.serverURL ?? "nil")") + + let transactions: [ZcashTransaction.Overview] + do { + transactions = try await transactionEncoder.createProposedTransactions( + proposal: proposal, + spendingKey: spendingKey + ) + } catch { + logger.info("[PIR-DEBUG] createProposedTransactions: status=\(isSyncing), usePIRWitnesses=\(proposal.pirWitnessConfig?.usePIRWitnesses ?? false), pirServerURL=\(proposal.pirWitnessConfig?.serverURL ?? "nil")") + if try await refreshProposalWitnessesIfNeeded(after: error, proposal: proposal) { + proposal.pirWitnessConfig?.usePIRWitnesses = true + transactions = try await transactionEncoder.createProposedTransactions( + proposal: proposal, + spendingKey: spendingKey + ) + } else { + throw error + } + } return submitTransactions(transactions) } @@ -1028,7 +1063,7 @@ public class SDKSynchronizer: Synchronizer { public func debugDatabase(sql: String) -> String { transactionRepository.debugDatabase(sql: sql) } - + public func getSingleUseTransparentAddress(accountUUID: AccountUUID) async throws -> SingleUseTransparentAddress { try await initializer.rustBackend.getSingleUseTransparentAddress(accountUUID: accountUUID) } @@ -1091,6 +1126,161 @@ public class SDKSynchronizer: Synchronizer { try await initializer.rustBackend.deleteAccount(accountUUID) } + public func checkWalletSpendability( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> SpendabilityResult { + // Read unspent notes (@DBActor — serialized with sync) + let notes = try await initializer.rustBackend.getUnspentOrchardNotesForPIR() + guard !notes.isEmpty else { + return SpendabilityResult(earliestHeight: 0, latestHeight: 0, spentNoteIds: [], totalSpentValue: 0) + } + + // Check nullifiers against PIR server (detached — no DB connection held) + let checkResult = try await Task.detached(priority: .userInitiated) { + try SpendabilityBackend().checkNullifiersPIR( + notes: notes, + pirServerUrl: pirServerUrl, + progress: progress + ) + }.value + + let spentNotes = zip(notes, checkResult.spent).filter { $0.1 != nil } + let spentNoteIds = spentNotes.map(\.0.id) + let totalSpentValue = spentNotes.map(\.0.value).reduce(0, +) + + return SpendabilityResult( + earliestHeight: checkResult.earliestHeight, + latestHeight: checkResult.latestHeight, + spentNoteIds: spentNoteIds, + totalSpentValue: totalSpentValue + ) + } + + public func fetchNoteWitnesses( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> WitnessResult { + let notes = try await initializer.rustBackend.getNotesNeedingPIRWitness() + guard !notes.isEmpty else { + return WitnessResult(witnessedNoteIds: [], totalWitnessedValue: 0) + } + + let witnessResult = try await Task.detached(priority: .userInitiated) { + try SpendabilityBackend().fetchWitnesses( + notes: notes, + pirServerUrl: pirServerUrl, + progress: progress + ) + }.value + + guard !witnessResult.witnesses.isEmpty else { + return WitnessResult(witnessedNoteIds: [], totalWitnessedValue: 0) + } + + try await initializer.rustBackend.insertPIRWitnesses(witnessResult.witnesses) + + let witnessedIds = witnessResult.witnesses.map(\.noteId) + let totalValue = notes + .filter { witnessedIds.contains($0.id) } + .map(\.value) + .reduce(0, +) + + return WitnessResult( + witnessedNoteIds: witnessedIds, + totalWitnessedValue: totalValue + ) + } + + private func alignProposalWitnesses(proposal: Proposal) async throws { + let currentStatus = await status + if case .synced = currentStatus { return } + if proposal.pirWitnessConfig?.usePIRWitnesses == false { return } + + guard let serverURL = proposal.pirWitnessConfig?.serverURL else { return } + + let notes = try await initializer.rustBackend.getPIRWitnessNotes(for: proposal.inner) + guard !notes.isEmpty else { return } + + logger.debug("Aligning PIR witnesses for \(notes.count) Orchard notes before tx creation") + let inserted = try await fetchAndInsertPIRWitnesses(notes: notes, pirServerUrl: serverURL) + if inserted > 0 { + logger.debug("Aligned \(inserted) PIR witnesses to same anchor") + } else { + logger.debug("Witness alignment skipped: server returned no witnesses") + } + } + + /// Fetches PIR witnesses for the given notes and inserts canonical ones into the DB. + /// + /// Returns the number of witnesses inserted, or 0 if the server returned none. + @discardableResult + private func fetchAndInsertPIRWitnesses( + notes: [PIRNotePosition], + pirServerUrl: String + ) async throws -> Int { + let fetchWitnesses = pirWitnessFetcher + let witnessResult = try await Task.detached(priority: .userInitiated) { + try fetchWitnesses(notes, pirServerUrl, nil) + }.value + + let canonical = witnessResult.witnesses.filter { $0.noteId > 0 } + guard !canonical.isEmpty else { return 0 } + try await initializer.rustBackend.insertPIRWitnesses(canonical) + return canonical.count + } + + // Returns true only when the synchronizer refreshed and inserted replacement + // PIR witnesses for Orchard notes referenced by this proposal. The refresh is + // intentionally scoped to proposal-selected notes instead of all notes that + // need witnesses, and nil server URL / empty note sets / empty witness + // responses are treated as "no retry performed". + private func refreshProposalWitnessesIfNeeded( + after error: Error, + proposal: Proposal + ) async throws -> Bool { + guard isPIRWitnessMismatch(error) else { + return false + } + + if proposal.pirWitnessConfig?.usePIRWitnesses == false { return false } + + logger.debug("PIR witness retry triggered after error: \(error.localizedDescription)") + + guard let serverURL = proposal.pirWitnessConfig?.serverURL else { + logger.warn("PIR witness retry skipped: no witness server URL provided") + return false + } + + let notes = try await initializer.rustBackend.getPIRWitnessNotes(for: proposal.inner) + guard !notes.isEmpty else { + logger.debug("PIR witness retry skipped: proposal contains no Orchard PIR witness candidates") + return false + } + + logger.debug("Refreshing PIR witnesses for \(notes.count) proposal-selected Orchard notes") + let inserted = try await fetchAndInsertPIRWitnesses(notes: notes, pirServerUrl: serverURL) + guard inserted > 0 else { + logger.debug("PIR witness retry skipped: witness server returned no witnesses") + return false + } + + logger.debug("PIR witness retry fetched \(inserted) witnesses") + return true + } + + // This currently relies on Rust error message substrings. Keep the heuristic + // narrow for retry safety, but replace it with a structured error signal when + // the backend exposes one. + private func isPIRWitnessMismatch(_ error: Error) -> Bool { + guard case let ZcashError.rustCreateToAddress(rustError) = error else { + return false + } + + return rustError.contains("incompatible PIR witness anchors") + || rustError.contains("All anchors must be equal") + } + // MARK: Server switch public func switchTo(endpoint: LightWalletEndpoint) async throws { diff --git a/Sources/ZcashLightClientKit/Transaction/TransactionEncoder.swift b/Sources/ZcashLightClientKit/Transaction/TransactionEncoder.swift index e6d18155c..11384d4cc 100644 --- a/Sources/ZcashLightClientKit/Transaction/TransactionEncoder.swift +++ b/Sources/ZcashLightClientKit/Transaction/TransactionEncoder.swift @@ -59,7 +59,9 @@ protocol TransactionEncoder { /// Creates the transactions in the given proposal. /// - /// - Parameter proposal: the proposal for which to create transactions. + /// - Parameter proposal: the proposal for which to create transactions. The proposal's + /// `pirWitnessConfig.usePIRWitnesses` flag controls whether Orchard witnesses are + /// read from PIR-stored data instead of the local ShardTree. /// - Parameter spendingKey: the `UnifiedSpendingKey` associated with the account for which the proposal was created. /// - Throws: /// - `walletTransEncoderCreateTransactionMissingSaplingParams` if the sapling parameters aren't downloaded. diff --git a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift index dbd12bbe2..06e3e1fbe 100644 --- a/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift +++ b/Sources/ZcashLightClientKit/Transaction/WalletTransactionEncoder.swift @@ -108,14 +108,19 @@ class WalletTransactionEncoder: TransactionEncoder { proposal: Proposal, spendingKey: UnifiedSpendingKey ) async throws -> [ZcashTransaction.Overview] { + let usePIRWitnesses = proposal.pirWitnessConfig?.usePIRWitnesses ?? false + logger.info("[PIR-DEBUG] WalletTransactionEncoder.createProposedTransactions: usePIRWitnesses=\(usePIRWitnesses)") + guard ensureParams(spend: self.spendParamsURL, output: self.outputParamsURL) else { throw ZcashError.walletTransEncoderCreateTransactionMissingSaplingParams } let txIds = try await rustBackend.createProposedTransactions( proposal: proposal.inner, - usk: spendingKey + usk: spendingKey, + usePIRWitnesses: usePIRWitnesses ) + logger.info("[PIR-DEBUG] WalletTransactionEncoder.createProposedTransactions completed, txIds count=\(txIds.count)") return try await fetchTransactionsForTxIds(txIds) } diff --git a/Sources/ZcashLightClientKit/Utils/SDKFlags.swift b/Sources/ZcashLightClientKit/Utils/SDKFlags.swift index ebacedef1..bf5487166 100644 --- a/Sources/ZcashLightClientKit/Utils/SDKFlags.swift +++ b/Sources/ZcashLightClientKit/Utils/SDKFlags.swift @@ -28,7 +28,12 @@ actor SDKFlags { /// Runtime helper flag used to mark whether chainTip CBP action has been done. var chainTipUpdated = false var chainTipUpdatedTimestamp: TimeInterval = 0.0 - + + /// Set after `checkWalletSpendability` succeeds. When true, Orchard + /// spendableValue is preserved even before `chainTipUpdated` is set. + var pirCompleted = false + var pirCompletedTimestamp: TimeInterval = 0.0 + init( torEnabled: Bool, exchangeRateEnabled: Bool @@ -64,11 +69,18 @@ actor SDKFlags { chainTipUpdatedTimestamp = Date().timeIntervalSince1970 } + /// Use to mark PIR spendability check as completed + func markPIRCompleted() { + pirCompleted = true + pirCompletedTimestamp = Date().timeIntervalSince1970 + } + /// The client using the SDK called `start()`. /// Use this to reset or update any relevant flags if needed. func sdkStarted() { + let now = Date().timeIntervalSince1970 // If chain tip has been updated recently and is set to false, re-enable it - if !chainTipUpdated && Date().timeIntervalSince1970 - chainTipUpdatedTimestamp < 120 { + if !chainTipUpdated && now - chainTipUpdatedTimestamp < 120 { chainTipUpdated = true } } diff --git a/Tests/OfflineTests/SpendabilityTypesTests.swift b/Tests/OfflineTests/SpendabilityTypesTests.swift new file mode 100644 index 000000000..253565582 --- /dev/null +++ b/Tests/OfflineTests/SpendabilityTypesTests.swift @@ -0,0 +1,869 @@ +// +// SpendabilityTypesTests.swift +// +// +// Tests for spendability and witness PIR types used across the FFI boundary, +// plus integration-level tests for the PIR witness retry and proactive +// alignment logic in SDKSynchronizer. +// + +import Foundation +@testable import TestUtils +import XCTest +@testable import ZcashLightClientKit + +final class SpendabilityTypesTests: ZcashTestCase { + let decoder = JSONDecoder() + let encoder = JSONEncoder() + + // MARK: - PIRUnspentNote + + func testPIRUnspentNoteDecodesFromRustJSON() throws { + let json = """ + {"id":42,"nf":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31],"value":50000} + """.data(using: .utf8)! + + let note = try decoder.decode(PIRUnspentNote.self, from: json) + + XCTAssertEqual(note.id, 42) + XCTAssertEqual(note.nf, Array(0...31)) + XCTAssertEqual(note.value, 50_000) + } + + func testPIRUnspentNoteArrayDecodesFromRustJSON() throws { + let json = """ + [ + {"id":1,"nf":[170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170],"value":10000}, + {"id":2,"nf":[187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187,187],"value":20000} + ] + """.data(using: .utf8)! + + let notes = try decoder.decode([PIRUnspentNote].self, from: json) + + XCTAssertEqual(notes.count, 2) + XCTAssertEqual(notes[0].id, 1) + XCTAssertEqual(notes[0].nf, [UInt8](repeating: 0xAA, count: 32)) + XCTAssertEqual(notes[0].value, 10_000) + XCTAssertEqual(notes[1].id, 2) + XCTAssertEqual(notes[1].nf, [UInt8](repeating: 0xBB, count: 32)) + XCTAssertEqual(notes[1].value, 20_000) + } + + func testPIRUnspentNoteRoundTrip() throws { + let note = PIRUnspentNote(id: 7, nf: [UInt8](repeating: 0xFF, count: 32), value: 100_000) + let data = try encoder.encode(note) + let decoded = try decoder.decode(PIRUnspentNote.self, from: data) + + XCTAssertEqual(note, decoded) + } + + func testPIRUnspentNoteEmptyArray() throws { + let json = "[]".data(using: .utf8)! + let notes = try decoder.decode([PIRUnspentNote].self, from: json) + XCTAssertTrue(notes.isEmpty) + } + + // MARK: - PIRSpendMetadata + + func testPIRSpendMetadataDecodesFromRustJSON() throws { + let json = """ + {"spend_height":2800000,"first_output_position":12345678,"action_count":4} + """.data(using: .utf8)! + + let meta = try decoder.decode(PIRSpendMetadata.self, from: json) + + XCTAssertEqual(meta.spendHeight, 2_800_000) + XCTAssertEqual(meta.firstOutputPosition, 12_345_678) + XCTAssertEqual(meta.actionCount, 4) + } + + func testPIRSpendMetadataRoundTrip() throws { + let meta = PIRSpendMetadata(spendHeight: 100, firstOutputPosition: 5000, actionCount: 3) + let data = try encoder.encode(meta) + let decoded = try decoder.decode(PIRSpendMetadata.self, from: data) + + XCTAssertEqual(meta, decoded) + } + + func testPIRSpendMetadataEncodesSnakeCaseKeys() throws { + let meta = PIRSpendMetadata(spendHeight: 100, firstOutputPosition: 5000, actionCount: 3) + let data = try encoder.encode(meta) + let jsonObject = try JSONSerialization.jsonObject(with: data) as! [String: Any] + + XCTAssertNotNil(jsonObject["spend_height"]) + XCTAssertNotNil(jsonObject["first_output_position"]) + XCTAssertNotNil(jsonObject["action_count"]) + XCTAssertNil(jsonObject["spendHeight"], "Should not use camelCase key") + } + + // MARK: - PIRNullifierCheckResult + + func testPIRNullifierCheckResultDecodesFromRustJSON() throws { + let json = """ + {"earliest_height":100,"latest_height":200,"spent":[{"spend_height":150,"first_output_position":5000,"action_count":3},null,{"spend_height":180,"first_output_position":8000,"action_count":1}]} + """.data(using: .utf8)! + + let result = try decoder.decode(PIRNullifierCheckResult.self, from: json) + + XCTAssertEqual(result.earliestHeight, 100) + XCTAssertEqual(result.latestHeight, 200) + XCTAssertEqual(result.spent.count, 3) + XCTAssertEqual(result.spent[0]?.spendHeight, 150) + XCTAssertEqual(result.spent[0]?.firstOutputPosition, 5000) + XCTAssertEqual(result.spent[0]?.actionCount, 3) + XCTAssertNil(result.spent[1]) + XCTAssertEqual(result.spent[2]?.spendHeight, 180) + } + + func testPIRNullifierCheckResultEmptySpent() throws { + let json = """ + {"earliest_height":0,"latest_height":0,"spent":[]} + """.data(using: .utf8)! + + let result = try decoder.decode(PIRNullifierCheckResult.self, from: json) + + XCTAssertEqual(result.earliestHeight, 0) + XCTAssertEqual(result.latestHeight, 0) + XCTAssertTrue(result.spent.isEmpty) + } + + func testPIRNullifierCheckResultEncodesSnakeCaseKeys() throws { + let meta = PIRSpendMetadata(spendHeight: 100, firstOutputPosition: 5000, actionCount: 3) + let result = PIRNullifierCheckResult(earliestHeight: 500, latestHeight: 1000, spent: [nil, meta]) + let data = try encoder.encode(result) + let jsonObject = try JSONSerialization.jsonObject(with: data) as! [String: Any] + + XCTAssertNotNil(jsonObject["earliest_height"], "Expected snake_case key 'earliest_height'") + XCTAssertNotNil(jsonObject["latest_height"], "Expected snake_case key 'latest_height'") + XCTAssertNotNil(jsonObject["spent"]) + XCTAssertNil(jsonObject["earliestHeight"], "Should not use camelCase key") + } + + func testPIRNullifierCheckResultRoundTrip() throws { + let meta1 = PIRSpendMetadata(spendHeight: 100, firstOutputPosition: 5000, actionCount: 2) + let meta2 = PIRSpendMetadata(spendHeight: 200, firstOutputPosition: 9000, actionCount: 1) + let result = PIRNullifierCheckResult(earliestHeight: 42, latestHeight: 99, spent: [meta1, meta2, nil]) + let data = try encoder.encode(result) + let decoded = try decoder.decode(PIRNullifierCheckResult.self, from: data) + + XCTAssertEqual(result, decoded) + } + + // MARK: - SpendabilityResult + + func testSpendabilityResultDecodesFromRustJSON() throws { + let json = """ + {"earliest_height":100,"latest_height":200,"spent_note_ids":[1,3],"total_spent_value":50000} + """.data(using: .utf8)! + + let result = try decoder.decode(SpendabilityResult.self, from: json) + + XCTAssertEqual(result.earliestHeight, 100) + XCTAssertEqual(result.latestHeight, 200) + XCTAssertEqual(result.spentNoteIds, [1, 3]) + XCTAssertEqual(result.totalSpentValue, 50_000) + } + + func testSpendabilityResultEmpty() throws { + let result = SpendabilityResult(earliestHeight: 0, latestHeight: 0, spentNoteIds: [], totalSpentValue: 0) + let data = try encoder.encode(result) + let decoded = try decoder.decode(SpendabilityResult.self, from: data) + + XCTAssertEqual(result, decoded) + XCTAssertTrue(decoded.spentNoteIds.isEmpty) + XCTAssertEqual(decoded.totalSpentValue, 0) + } + + func testSpendabilityResultEncodesSnakeCaseKeys() throws { + let result = SpendabilityResult(earliestHeight: 1, latestHeight: 2, spentNoteIds: [5], totalSpentValue: 999) + let data = try encoder.encode(result) + let jsonObject = try JSONSerialization.jsonObject(with: data) as! [String: Any] + + XCTAssertNotNil(jsonObject["earliest_height"]) + XCTAssertNotNil(jsonObject["latest_height"]) + XCTAssertNotNil(jsonObject["spent_note_ids"]) + XCTAssertNotNil(jsonObject["total_spent_value"]) + } + + // MARK: - Cross-type consistency: notes → check → result pipeline + + func testThreePhasePipelineTypes() throws { + let meta1 = PIRSpendMetadata(spendHeight: 150, firstOutputPosition: 5000, actionCount: 3) + let meta3 = PIRSpendMetadata(spendHeight: 180, firstOutputPosition: 8000, actionCount: 1) + + let notes = [ + PIRUnspentNote(id: 1, nf: [UInt8](repeating: 0xAA, count: 32), value: 10_000), + PIRUnspentNote(id: 2, nf: [UInt8](repeating: 0xBB, count: 32), value: 20_000), + PIRUnspentNote(id: 3, nf: [UInt8](repeating: 0xCC, count: 32), value: 30_000) + ] + + let checkResult = PIRNullifierCheckResult( + earliestHeight: 100, + latestHeight: 200, + spent: [meta1, nil, meta3] + ) + + XCTAssertEqual(notes.count, checkResult.spent.count, "Spent entries must be parallel to notes") + + let spentNotes = zip(notes, checkResult.spent).filter { $0.1 != nil } + let spentNoteIds = spentNotes.map(\.0.id) + let totalSpentValue = spentNotes.map(\.0.value).reduce(0, +) + + XCTAssertEqual(spentNoteIds, [1, 3]) + XCTAssertEqual(totalSpentValue, 40_000) + + let finalResult = SpendabilityResult( + earliestHeight: checkResult.earliestHeight, + latestHeight: checkResult.latestHeight, + spentNoteIds: spentNoteIds, + totalSpentValue: totalSpentValue + ) + + XCTAssertEqual(finalResult.earliestHeight, 100) + XCTAssertEqual(finalResult.latestHeight, 200) + XCTAssertEqual(finalResult.spentNoteIds, [1, 3]) + XCTAssertEqual(finalResult.totalSpentValue, 40_000) + } + + func testThreePhasePipelineNoNotesSpent() throws { + let notes = [ + PIRUnspentNote(id: 1, nf: [UInt8](repeating: 0xAA, count: 32), value: 10_000), + PIRUnspentNote(id: 2, nf: [UInt8](repeating: 0xBB, count: 32), value: 20_000) + ] + + let checkResult = PIRNullifierCheckResult( + earliestHeight: 50, + latestHeight: 150, + spent: [nil, nil] + ) + + let spentNotes = zip(notes, checkResult.spent).filter { $0.1 != nil } + XCTAssertTrue(spentNotes.isEmpty) + XCTAssertEqual(spentNotes.map(\.0.value).reduce(0, +), 0) + } + + func testThreePhasePipelineAllNotesSpent() throws { + let meta1 = PIRSpendMetadata(spendHeight: 100, firstOutputPosition: 3000, actionCount: 2) + let meta2 = PIRSpendMetadata(spendHeight: 120, firstOutputPosition: 4000, actionCount: 1) + + let notes = [ + PIRUnspentNote(id: 1, nf: [UInt8](repeating: 0xAA, count: 32), value: 10_000), + PIRUnspentNote(id: 2, nf: [UInt8](repeating: 0xBB, count: 32), value: 20_000) + ] + + let checkResult = PIRNullifierCheckResult( + earliestHeight: 50, + latestHeight: 150, + spent: [meta1, meta2] + ) + + let spentNotes = zip(notes, checkResult.spent).filter { $0.1 != nil } + XCTAssertEqual(spentNotes.count, 2) + XCTAssertEqual(spentNotes.map(\.0.id), [1, 2]) + XCTAssertEqual(spentNotes.map(\.0.value).reduce(0, +), 30_000) + } + + // MARK: - PIRNotePosition + + func testPIRNotePositionDecodesFromRustJSON() throws { + let json = """ + {"id":42,"position":1000,"value":50000} + """.data(using: .utf8)! + + let note = try decoder.decode(PIRNotePosition.self, from: json) + + XCTAssertEqual(note.id, 42) + XCTAssertEqual(note.position, 1000) + XCTAssertEqual(note.value, 50_000) + } + + func testPIRNotePositionArrayDecodesFromRustJSON() throws { + let json = """ + [ + {"id":1,"position":100,"value":10000}, + {"id":2,"position":200,"value":20000} + ] + """.data(using: .utf8)! + + let notes = try decoder.decode([PIRNotePosition].self, from: json) + + XCTAssertEqual(notes.count, 2) + XCTAssertEqual(notes[0].id, 1) + XCTAssertEqual(notes[0].position, 100) + XCTAssertEqual(notes[0].value, 10_000) + XCTAssertEqual(notes[1].id, 2) + XCTAssertEqual(notes[1].position, 200) + XCTAssertEqual(notes[1].value, 20_000) + } + + func testPIRNotePositionRoundTrip() throws { + let note = PIRNotePosition(id: 7, position: 999, value: 100_000) + let data = try encoder.encode(note) + let decoded = try decoder.decode(PIRNotePosition.self, from: data) + + XCTAssertEqual(note, decoded) + } + + func testPIRNotePositionEmptyArray() throws { + let json = "[]".data(using: .utf8)! + let notes = try decoder.decode([PIRNotePosition].self, from: json) + XCTAssertTrue(notes.isEmpty) + } + + // MARK: - PIRWitnessEntry + + func testPIRWitnessEntryDecodesFromRustJSON() throws { + let sibling = String(repeating: "aa", count: 32) + let root = String(repeating: "bb", count: 32) + let json = """ + {"note_id":42,"position":1000,"siblings":["\(sibling)"],"anchor_height":3200000,"anchor_root":"\(root)"} + """.data(using: .utf8)! + + let entry = try decoder.decode(PIRWitnessEntry.self, from: json) + + XCTAssertEqual(entry.noteId, 42) + XCTAssertEqual(entry.position, 1000) + XCTAssertEqual(entry.siblings.count, 1) + XCTAssertEqual(entry.siblings[0], sibling) + XCTAssertEqual(entry.anchorHeight, 3_200_000) + XCTAssertEqual(entry.anchorRoot, root) + } + + func testPIRWitnessEntryEncodesSnakeCaseKeys() throws { + let entry = PIRWitnessEntry( + noteId: 1, + position: 500, + siblings: [String(repeating: "cc", count: 32)], + anchorHeight: 100, + anchorRoot: String(repeating: "dd", count: 32) + ) + let data = try encoder.encode(entry) + let jsonObject = try JSONSerialization.jsonObject(with: data) as! [String: Any] + + XCTAssertNotNil(jsonObject["note_id"], "Expected snake_case key 'note_id'") + XCTAssertNotNil(jsonObject["anchor_height"], "Expected snake_case key 'anchor_height'") + XCTAssertNotNil(jsonObject["anchor_root"], "Expected snake_case key 'anchor_root'") + XCTAssertNil(jsonObject["noteId"], "Should not use camelCase key") + XCTAssertNil(jsonObject["anchorHeight"], "Should not use camelCase key") + XCTAssertNil(jsonObject["anchorRoot"], "Should not use camelCase key") + } + + func testPIRWitnessEntryRoundTrip() throws { + let siblings = (0..<32).map { _ in String(repeating: "ab", count: 32) } + let entry = PIRWitnessEntry( + noteId: 99, + position: 12345, + siblings: siblings, + anchorHeight: 3_200_000, + anchorRoot: String(repeating: "ff", count: 32) + ) + let data = try encoder.encode(entry) + let decoded = try decoder.decode(PIRWitnessEntry.self, from: data) + + XCTAssertEqual(entry, decoded) + } + + // MARK: - PIRWitnessResult + + func testPIRWitnessResultDecodesFromRustJSON() throws { + let sibling = String(repeating: "aa", count: 32) + let root = String(repeating: "bb", count: 32) + let json = """ + {"witnesses":[{"note_id":42,"position":1000,"siblings":["\(sibling)"],"anchor_height":3200000,"anchor_root":"\(root)"}]} + """.data(using: .utf8)! + + let result = try decoder.decode(PIRWitnessResult.self, from: json) + + XCTAssertEqual(result.witnesses.count, 1) + XCTAssertEqual(result.witnesses[0].noteId, 42) + XCTAssertEqual(result.witnesses[0].anchorRoot, root) + } + + func testPIRWitnessResultEmpty() throws { + let json = """ + {"witnesses":[]} + """.data(using: .utf8)! + + let result = try decoder.decode(PIRWitnessResult.self, from: json) + XCTAssertTrue(result.witnesses.isEmpty) + } + + func testPIRWitnessResultRoundTrip() throws { + let result = PIRWitnessResult(witnesses: [ + PIRWitnessEntry( + noteId: 1, + position: 100, + siblings: [String(repeating: "aa", count: 32)], + anchorHeight: 500, + anchorRoot: String(repeating: "bb", count: 32) + ) + ]) + let data = try encoder.encode(result) + let decoded = try decoder.decode(PIRWitnessResult.self, from: data) + + XCTAssertEqual(result, decoded) + } + + // MARK: - WitnessResult (in-process only, not Codable) + + func testWitnessResultEquality() { + let a = WitnessResult(witnessedNoteIds: [1, 2, 3], totalWitnessedValue: 30_000) + let b = WitnessResult(witnessedNoteIds: [1, 2, 3], totalWitnessedValue: 30_000) + let c = WitnessResult(witnessedNoteIds: [1, 2], totalWitnessedValue: 20_000) + + XCTAssertEqual(a, b) + XCTAssertNotEqual(a, c) + } + + func testWitnessResultEmpty() { + let result = WitnessResult(witnessedNoteIds: [], totalWitnessedValue: 0) + + XCTAssertTrue(result.witnessedNoteIds.isEmpty) + XCTAssertEqual(result.totalWitnessedValue, 0) + } + + // MARK: - PIR witness retry + + private final class RetryTestTransactionEncoder: TransactionEncoder { + enum StubbedResult { + case success([ZcashTransaction.Overview]) + case failure(Error) + } + + private(set) var createProposedTransactionsCallsCount = 0 + private(set) var usePIRWitnessesHistory: [Bool] = [] + var createResults: [StubbedResult] = [] + + func createProposedTransactions( + proposal: Proposal, + spendingKey: UnifiedSpendingKey + ) async throws -> [ZcashTransaction.Overview] { + createProposedTransactionsCallsCount += 1 + usePIRWitnessesHistory.append(proposal.pirWitnessConfig?.usePIRWitnesses ?? false) + let index = createProposedTransactionsCallsCount - 1 + guard createResults.indices.contains(index) else { + XCTFail("Missing stubbed result for createProposedTransactions call \(index + 1)") + return [] + } + + switch createResults[index] { + case .success(let transactions): + return transactions + case .failure(let error): + throw error + } + } + + func proposeTransfer( + accountUUID: AccountUUID, + recipient: String, + amount: Zatoshi, + memoBytes: MemoBytes? + ) async throws -> Proposal { + fatalError("Unused in PIR witness retry tests") + } + + func proposeShielding( + accountUUID: AccountUUID, + shieldingThreshold: Zatoshi, + memoBytes: MemoBytes?, + transparentReceiver: String? + ) async throws -> Proposal? { + fatalError("Unused in PIR witness retry tests") + } + + func proposeFulfillingPaymentFromURI( + _ uri: String, + accountUUID: AccountUUID + ) async throws -> Proposal { + fatalError("Unused in PIR witness retry tests") + } + + func submit(transaction: EncodedTransaction) async throws { + fatalError("Unused in PIR witness retry tests") + } + + func fetchTransactionsForTxIds(_ txIds: [Data]) async throws -> [ZcashTransaction.Overview] { + fatalError("Unused in PIR witness retry tests") + } + + func closeDBConnection() {} + } + + private func makeSpendingKey(network: ZcashNetwork) throws -> UnifiedSpendingKey { + let derivationTool = DerivationTool(networkType: network.networkType) + return try derivationTool.deriveUnifiedSpendingKey( + seed: Environment.seedBytes, + accountIndex: Zip32AccountIndex(0) + ) + } + + private func makeProposal() -> Proposal { + Proposal(inner: FfiProposal()) + } + + private func makeNotePosition() -> PIRNotePosition { + PIRNotePosition(id: 1, position: 42, value: 60_000) + } + + private func makeWitnessEntry() -> PIRWitnessEntry { + PIRWitnessEntry( + noteId: 1, + position: 42, + siblings: Array(repeating: String(repeating: "00", count: 32), count: 32), + anchorHeight: 1_000, + anchorRoot: String(repeating: "11", count: 32) + ) + } + + private func makeSynchronizer( + rustBackend: ZcashRustBackendWeldingMock, + transactionEncoder: RetryTestTransactionEncoder, + syncStatus: InternalSyncStatus = .synced, + pirWitnessFetcher: @escaping SDKSynchronizer.PIRWitnessFetcher = { _, _, _ in + preconditionFailure("Unexpected PIR witness fetch") + } + ) async throws -> SDKSynchronizer { + let network = ZcashNetworkBuilder.network(for: .testnet) + mockContainer.mock(type: ZcashRustBackendWelding.self, isSingleton: true) { _ in rustBackend } + + let initializer = Initializer( + container: mockContainer, + cacheDbURL: nil, + fsBlockDbRoot: testTempDirectory, + generalStorageURL: testGeneralStorageDirectory, + dataDbURL: testTempDirectory.appendingPathComponent("data.db"), + torDirURL: testTempDirectory.appendingPathComponent("tor"), + endpoint: LightWalletEndpointBuilder.default, + network: network, + spendParamsURL: SaplingParamsSourceURL.tests.spendParamFileURL, + outputParamsURL: SaplingParamsSourceURL.tests.outputParamFileURL, + saplingParamsSourceURL: .tests, + alias: .default, + loggingPolicy: .noLogging, + isTorEnabled: false, + isExchangeRateEnabled: false + ) + + let blockProcessor = CompactBlockProcessor( + initializer: initializer, + walletBirthdayProvider: { 1 } + ) + let synchronizer = SDKSynchronizer( + status: .unprepared, + initializer: initializer, + transactionEncoder: transactionEncoder, + transactionRepository: initializer.transactionRepository, + blockProcessor: blockProcessor, + syncSessionTicker: .live, + pirWitnessFetcher: pirWitnessFetcher + ) + await synchronizer.updateStatus(syncStatus, updateExternalStatus: false) + + return synchronizer + } + + func testCreateProposedTransactionsRetriesOnceAfterPIRMismatch() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [ + .failure(ZcashError.rustCreateToAddress("Selected Orchard inputs were backed by incompatible PIR witness anchors.")), + .success([]) + ] + let note = makeNotePosition() + rustBackend.getPIRWitnessNotesReturnValue = [note] + + let witnessEntry = makeWitnessEntry() + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + pirWitnessFetcher: { _, _, _ in + return PIRWitnessResult(witnesses: [witnessEntry]) + } + ) + + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + let stream = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + + var iterator = stream.makeAsyncIterator() + let next = try await iterator.next() + XCTAssertNil(next) + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 2) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false, true]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 1) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 1) + XCTAssertEqual(rustBackend.insertPIRWitnessesReceivedWitnesses, [witnessEntry]) + } + + func testCreateProposedTransactionsDoesNotRetryForNonPIRFailure() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [ + .failure(ZcashError.rustCreateToAddress("proposal construction failed")) + ] + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder + ) + + do { + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + _ = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + XCTFail("Expected transaction creation to fail") + } catch let ZcashError.rustCreateToAddress(message) { + XCTAssertEqual(message, "proposal construction failed") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 1) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 0) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 0) + } + + func testCreateProposedTransactionsDoesNotRetryWithoutWitnessServerURL() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [ + .failure(ZcashError.rustCreateToAddress("Selected Orchard inputs were backed by incompatible PIR witness anchors.")) + ] + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder + ) + + do { + _ = try await synchronizer.createProposedTransactions( + proposal: makeProposal(), + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + XCTFail("Expected transaction creation to fail") + } catch let ZcashError.rustCreateToAddress(message) { + XCTAssertEqual(message, "Selected Orchard inputs were backed by incompatible PIR witness anchors.") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 1) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 0) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 0) + } + + func testCreateProposedTransactionsDoesNotRetryWhenProposalHasNoPIRWitnessNotes() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [ + .failure(ZcashError.rustCreateToAddress("All anchors must be equal")) + ] + rustBackend.getPIRWitnessNotesReturnValue = [] + + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder + ) + + do { + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + _ = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + XCTFail("Expected transaction creation to fail") + } catch let ZcashError.rustCreateToAddress(message) { + XCTAssertEqual(message, "All anchors must be equal") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 1) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 1) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 0) + } + + func testCreateProposedTransactionsDoesNotLoopAfterSecondFailure() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + let witnessEntry = makeWitnessEntry() + transactionEncoder.createResults = [ + .failure(ZcashError.rustCreateToAddress("Selected Orchard inputs were backed by incompatible PIR witness anchors.")), + .failure(ZcashError.rustCreateToAddress("All anchors must be equal")) + ] + rustBackend.getPIRWitnessNotesReturnValue = [makeNotePosition()] + + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + pirWitnessFetcher: { _, _, _ in + PIRWitnessResult(witnesses: [witnessEntry]) + } + ) + + do { + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + _ = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + XCTFail("Expected second transaction creation attempt to fail") + } catch let ZcashError.rustCreateToAddress(message) { + XCTAssertEqual(message, "All anchors must be equal") + } catch { + XCTFail("Unexpected error: \(error)") + } + + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 2) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false, true]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 1) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 1) + } + + // MARK: - Proactive alignment + + func testProactiveAlignmentFetchesWitnessesWhenSyncing() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [.success([])] + + let note = makeNotePosition() + rustBackend.getPIRWitnessNotesReturnValue = [note] + + let witnessEntry = makeWitnessEntry() + var fetchCount = 0 + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + syncStatus: .syncing(0.5, false), + pirWitnessFetcher: { _, _, _ in + fetchCount += 1 + return PIRWitnessResult(witnesses: [witnessEntry]) + } + ) + + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + let stream = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + + var iterator = stream.makeAsyncIterator() + let next = try await iterator.next() + XCTAssertNil(next) + XCTAssertEqual(fetchCount, 1, "Proactive alignment should fetch witnesses once") + XCTAssertEqual(transactionEncoder.createProposedTransactionsCallsCount, 1) + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [true]) + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 1) + } + + func testProactiveAlignmentSkippedWhenSynced() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [.success([])] + rustBackend.getPIRWitnessNotesReturnValue = [makeNotePosition()] + + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + syncStatus: .synced + ) + + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + let stream = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + + var iterator = stream.makeAsyncIterator() + _ = try await iterator.next() + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false]) + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 0, "Should not query notes when synced") + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 0) + } + + func testProactiveAlignmentSkippedWithoutWitnessServerURL() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [.success([])] + + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + syncStatus: .syncing(0.3, false) + ) + + let stream = try await synchronizer.createProposedTransactions( + proposal: makeProposal(), + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + + var iterator = stream.makeAsyncIterator() + _ = try await iterator.next() + XCTAssertEqual(transactionEncoder.usePIRWitnessesHistory, [false], "Should not use PIR witnesses without witness URL") + XCTAssertEqual(rustBackend.getPIRWitnessNotesCallsCount, 0, "Should not query notes without witness URL") + } + + func testWitnessInsertionIgnoresNonCanonicalEntries() async throws { + let rustBackend = ZcashRustBackendWeldingMock() + let transactionEncoder = RetryTestTransactionEncoder() + transactionEncoder.createResults = [.success([])] + + let canonicalNote = PIRNotePosition(id: 5, position: 100, value: 50_000) + let nonCanonicalNote = PIRNotePosition(id: -3, position: 200, value: 30_000) + rustBackend.getPIRWitnessNotesReturnValue = [canonicalNote, nonCanonicalNote] + + let canonicalWitness = PIRWitnessEntry( + noteId: 5, + position: 100, + siblings: Array(repeating: String(repeating: "aa", count: 32), count: 32), + anchorHeight: 2_000, + anchorRoot: String(repeating: "bb", count: 32) + ) + let nonCanonicalWitness = PIRWitnessEntry( + noteId: -3, + position: 200, + siblings: Array(repeating: String(repeating: "cc", count: 32), count: 32), + anchorHeight: 2_000, + anchorRoot: String(repeating: "dd", count: 32) + ) + + let synchronizer = try await makeSynchronizer( + rustBackend: rustBackend, + transactionEncoder: transactionEncoder, + syncStatus: .syncing(0.5, false), + pirWitnessFetcher: { _, _, _ in + PIRWitnessResult(witnesses: [canonicalWitness, nonCanonicalWitness]) + } + ) + + var proposal = makeProposal() + proposal.pirWitnessConfig = Proposal.PIRWitnessConfig(serverURL: "http://localhost:8080") + let stream = try await synchronizer.createProposedTransactions( + proposal: proposal, + spendingKey: try makeSpendingKey(network: synchronizer.network) + ) + + var iterator = stream.makeAsyncIterator() + _ = try await iterator.next() + + XCTAssertEqual(rustBackend.insertPIRWitnessesCallsCount, 1, "Canonical witness should use insertPIRWitnesses") + XCTAssertEqual( + rustBackend.insertPIRWitnessesReceivedWitnesses, + [canonicalWitness], + "Only canonical (positive ID) witnesses should be inserted in the minimal design" + ) + } +} diff --git a/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift index 0e3c09288..957a322ae 100644 --- a/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift +++ b/Tests/TestUtils/Sourcery/GeneratedMocks/AutoMockable.generated.swift @@ -2504,6 +2504,53 @@ class SynchronizerMock: Synchronizer { try await deleteAccountClosure!(accountUUID) } + // MARK: - checkWalletSpendability + + var checkWalletSpendabilityThrowableError: Error? + var checkWalletSpendabilityCallsCount = 0 + var checkWalletSpendabilityCalled: Bool { + return checkWalletSpendabilityCallsCount > 0 + } + var checkWalletSpendabilityReturnValue: SpendabilityResult! + var checkWalletSpendabilityClosure: ((String, SpendabilityProgressHandler?) async throws -> SpendabilityResult)? + + func checkWalletSpendability( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> SpendabilityResult { + if let error = checkWalletSpendabilityThrowableError { + throw error + } + checkWalletSpendabilityCallsCount += 1 + if let closure = checkWalletSpendabilityClosure { + return try await closure(pirServerUrl, progress) + } else { + return checkWalletSpendabilityReturnValue + } + } + + // MARK: - fetchNoteWitnesses + + var fetchNoteWitnessesThrowableError: Error? + var fetchNoteWitnessesCallsCount = 0 + var fetchNoteWitnessesReturnValue: WitnessResult! + var fetchNoteWitnessesClosure: ((String, SpendabilityProgressHandler?) async throws -> WitnessResult)? + + func fetchNoteWitnesses( + pirServerUrl: String, + progress: SpendabilityProgressHandler? + ) async throws -> WitnessResult { + if let error = fetchNoteWitnessesThrowableError { + throw error + } + fetchNoteWitnessesCallsCount += 1 + if let closure = fetchNoteWitnessesClosure { + return try await closure(pirServerUrl, progress) + } else { + return fetchNoteWitnessesReturnValue + } + } + } class TransactionRepositoryMock: TransactionRepository { @@ -3621,25 +3668,25 @@ class ZcashRustBackendWeldingMock: ZcashRustBackendWelding { // MARK: - createProposedTransactions - var createProposedTransactionsProposalUskThrowableError: Error? - var createProposedTransactionsProposalUskCallsCount = 0 - var createProposedTransactionsProposalUskCalled: Bool { - return createProposedTransactionsProposalUskCallsCount > 0 + var createProposedTransactionsProposalUskUsePIRWitnessesThrowableError: Error? + var createProposedTransactionsProposalUskUsePIRWitnessesCallsCount = 0 + var createProposedTransactionsProposalUskUsePIRWitnessesCalled: Bool { + return createProposedTransactionsProposalUskUsePIRWitnessesCallsCount > 0 } - var createProposedTransactionsProposalUskReceivedArguments: (proposal: FfiProposal, usk: UnifiedSpendingKey)? - var createProposedTransactionsProposalUskReturnValue: [Data]! - var createProposedTransactionsProposalUskClosure: ((FfiProposal, UnifiedSpendingKey) async throws -> [Data])? + var createProposedTransactionsProposalUskUsePIRWitnessesReceivedArguments: (proposal: FfiProposal, usk: UnifiedSpendingKey, usePIRWitnesses: Bool)? + var createProposedTransactionsProposalUskUsePIRWitnessesReturnValue: [Data]! + var createProposedTransactionsProposalUskUsePIRWitnessesClosure: ((FfiProposal, UnifiedSpendingKey, Bool) async throws -> [Data])? - func createProposedTransactions(proposal: FfiProposal, usk: UnifiedSpendingKey) async throws -> [Data] { - if let error = createProposedTransactionsProposalUskThrowableError { + func createProposedTransactions(proposal: FfiProposal, usk: UnifiedSpendingKey, usePIRWitnesses: Bool) async throws -> [Data] { + if let error = createProposedTransactionsProposalUskUsePIRWitnessesThrowableError { throw error } - createProposedTransactionsProposalUskCallsCount += 1 - createProposedTransactionsProposalUskReceivedArguments = (proposal: proposal, usk: usk) - if let closure = createProposedTransactionsProposalUskClosure { - return try await closure(proposal, usk) + createProposedTransactionsProposalUskUsePIRWitnessesCallsCount += 1 + createProposedTransactionsProposalUskUsePIRWitnessesReceivedArguments = (proposal: proposal, usk: usk, usePIRWitnesses: usePIRWitnesses) + if let closure = createProposedTransactionsProposalUskUsePIRWitnessesClosure { + return try await closure(proposal, usk, usePIRWitnesses) } else { - return createProposedTransactionsProposalUskReturnValue + return createProposedTransactionsProposalUskUsePIRWitnessesReturnValue } } @@ -3938,4 +3985,80 @@ class ZcashRustBackendWeldingMock: ZcashRustBackendWelding { try await deleteAccountClosure!(accountUUID) } + // MARK: - getUnspentOrchardNotesForPIR + + var getUnspentOrchardNotesForPIRThrowableError: Error? + var getUnspentOrchardNotesForPIRCallsCount = 0 + var getUnspentOrchardNotesForPIRReturnValue: [PIRUnspentNote] = [] + var getUnspentOrchardNotesForPIRClosure: (() async throws -> [PIRUnspentNote])? + + func getUnspentOrchardNotesForPIR() async throws -> [PIRUnspentNote] { + if let error = getUnspentOrchardNotesForPIRThrowableError { + throw error + } + getUnspentOrchardNotesForPIRCallsCount += 1 + if let closure = getUnspentOrchardNotesForPIRClosure { + return try await closure() + } else { + return getUnspentOrchardNotesForPIRReturnValue + } + } + + // MARK: - getNotesNeedingPIRWitness + + var getNotesNeedingPIRWitnessThrowableError: Error? + var getNotesNeedingPIRWitnessCallsCount = 0 + var getNotesNeedingPIRWitnessReturnValue: [PIRNotePosition] = [] + var getNotesNeedingPIRWitnessClosure: (() async throws -> [PIRNotePosition])? + + func getNotesNeedingPIRWitness() async throws -> [PIRNotePosition] { + if let error = getNotesNeedingPIRWitnessThrowableError { + throw error + } + getNotesNeedingPIRWitnessCallsCount += 1 + if let closure = getNotesNeedingPIRWitnessClosure { + return try await closure() + } else { + return getNotesNeedingPIRWitnessReturnValue + } + } + + // MARK: - getPIRWitnessNotes + + var getPIRWitnessNotesThrowableError: Error? + var getPIRWitnessNotesCallsCount = 0 + var getPIRWitnessNotesReceivedProposal: FfiProposal? + var getPIRWitnessNotesReturnValue: [PIRNotePosition] = [] + var getPIRWitnessNotesClosure: ((FfiProposal) async throws -> [PIRNotePosition])? + + func getPIRWitnessNotes(for proposal: FfiProposal) async throws -> [PIRNotePosition] { + if let error = getPIRWitnessNotesThrowableError { + throw error + } + getPIRWitnessNotesCallsCount += 1 + getPIRWitnessNotesReceivedProposal = proposal + if let closure = getPIRWitnessNotesClosure { + return try await closure(proposal) + } else { + return getPIRWitnessNotesReturnValue + } + } + + // MARK: - insertPIRWitnesses + + var insertPIRWitnessesThrowableError: Error? + var insertPIRWitnessesCallsCount = 0 + var insertPIRWitnessesReceivedWitnesses: [PIRWitnessEntry]? + var insertPIRWitnessesClosure: (([PIRWitnessEntry]) async throws -> Void)? + + func insertPIRWitnesses(_ witnesses: [PIRWitnessEntry]) async throws { + if let error = insertPIRWitnessesThrowableError { + throw error + } + insertPIRWitnessesCallsCount += 1 + insertPIRWitnessesReceivedWitnesses = witnesses + try await insertPIRWitnessesClosure?(witnesses) + } + + } diff --git a/Tests/TestUtils/Stubs.swift b/Tests/TestUtils/Stubs.swift index 1eb4f372a..a3f1a0ebb 100644 --- a/Tests/TestUtils/Stubs.swift +++ b/Tests/TestUtils/Stubs.swift @@ -85,7 +85,7 @@ class RustBackendMockHelper { rustBackendMock.proposeTransferAccountUUIDToValueMemoThrowableError = ZcashError.rustCreateToAddress("mocked error") let error = ZcashError.rustShieldFunds("mocked error") rustBackendMock.proposeShieldingAccountUUIDMemoShieldingThresholdTransparentReceiverThrowableError = error - rustBackendMock.createProposedTransactionsProposalUskThrowableError = ZcashError.rustCreateToAddress("mocked error") + rustBackendMock.createProposedTransactionsProposalUskUsePIRWitnessesThrowableError = ZcashError.rustCreateToAddress("mocked error") rustBackendMock.decryptAndStoreTransactionTxBytesMinedHeightThrowableError = ZcashError.rustDecryptAndStoreTransaction("mock fail") rustBackendMock.initDataDbSeedClosure = { seed in diff --git a/rust/CHANGELOG.md b/rust/CHANGELOG.md index 6a73419c4..f36372323 100644 --- a/rust/CHANGELOG.md +++ b/rust/CHANGELOG.md @@ -6,6 +6,19 @@ and this library adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added +- PIR (Private Information Retrieval) FFI for nullifier spend-checking and witness retrieval: + - `zcashlc_check_nullifiers_pir` — batch-check nullifiers against a PIR server. + - `zcashlc_fetch_pir_witnesses` — fetch Orchard note commitment witnesses from a PIR server. + - `zcashlc_get_unspent_orchard_notes_for_pir` — return unspent Orchard notes with nullifiers for PIR queries. + - `zcashlc_get_notes_needing_pir_witness` — return canonical notes that need a PIR witness fetch. + - `zcashlc_get_pir_witness_notes_for_proposal` — return notes selected by a proposal that may need witness refresh. + - `zcashlc_insert_pir_witnesses` — store PIR-obtained witnesses in the wallet DB. + + +### Changed +- `zcashlc_create_proposed_transactions` now accepts a `use_pir_witnesses` parameter to read Orchard witnesses from PIR-stored data instead of the local ShardTree. + ## 2.4.6 - 2026-03-12 ### Changed diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 770a41153..d8592fff3 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -94,6 +94,7 @@ use zip32::fingerprint::SeedFingerprint; mod derivation; mod ffi; +mod spendability; mod tor; #[cfg(target_vendor = "apple")] @@ -2474,6 +2475,9 @@ pub unsafe extern "C" fn zcashlc_propose_shielding( /// - `output_params`: A pointer to a buffer containing the operating system path of the Sapling /// output proving parameters, in the operating system's preferred path representation. /// - `output_params_len`: the length of the `output_params` buffer. +/// - `use_pir_witnesses`: When `true`, Orchard witnesses are read from PIR-stored data +/// instead of the local ShardTree. `spendability-pir` feature must be enabled. Otherwise, +/// this parameter is ignored. /// /// # Safety /// @@ -2522,6 +2526,7 @@ pub unsafe extern "C" fn zcashlc_create_proposed_transactions( output_params: *const u8, output_params_len: usize, network_id: u32, + use_pir_witnesses: bool, ) -> *mut ffi::TxIds { let res = catch_panic(|| { let network = parse_network(network_id)?; @@ -2549,6 +2554,7 @@ pub unsafe extern "C" fn zcashlc_create_proposed_transactions( &SpendingKeys::from_unified_spending_key(usk), OvkPolicy::Sender, &proposal, + use_pir_witnesses, ) .map_err(|e| anyhow!("Error while sending funds: {}", e))?; @@ -4121,9 +4127,404 @@ fn free_ptr_from_vec_with(ptr: *mut T, len: usize, f: impl Fn(&mut T)) { } } +// --------------------------------------------------------------------------- +// PIR FFI — DB-touching operations that go through wallet_db() so they share +// the same connection pattern as every other zcashlc_* function. +// --------------------------------------------------------------------------- + +/// Returns unspent Orchard notes with nullifiers for PIR spend-checking. +/// +/// Returns JSON `[{"id":i64,"nf":[u8, ...],"value":u64}, ...]`, or null on error. +/// +/// # Safety +/// +/// - `db_data` must be non-null and valid for reads for `db_data_len` bytes. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_get_unspent_orchard_notes_for_pir( + db_data: *const u8, + db_data_len: usize, + network_id: u32, +) -> *mut ffi::BoxedSlice { + let res = catch_panic(|| { + let network = parse_network(network_id)?; + let db_data = unsafe { wallet_db(db_data, db_data_len, network)? }; + + let notes = db_data + .get_unspent_orchard_notes_for_pir() + .map_err(|e| anyhow!("failed to query unspent orchard notes for PIR: {e}"))?; + + #[derive(serde::Serialize)] + struct Note { + id: i64, + nf: Vec, + value: u64, + } + + let out: Vec = notes + .into_iter() + .map(|n| Note { + id: n.id, + nf: n.nf.to_vec(), + value: n.value, + }) + .collect(); + + let json = serde_json::to_vec(&out)?; + Ok(ffi::BoxedSlice::some(json)) + }); + unwrap_exc_or_null(res) +} + +// --------------------------------------------------------------------------- +// Witness PIR FFI — DB-touching operations for PIR note commitment witnesses. +// --------------------------------------------------------------------------- + +/// JSON wire type shared by witness-related FFI endpoints that return +/// `[{"id":i64,"position":u64,"value":u64}, ...]`. +#[derive(serde::Serialize)] +struct PirNotePosition { + id: i64, + position: u64, + value: u64, +} + +/// Returns canonical Orchard notes that should be considered for PIR witness +/// fetch or refresh. +/// +/// Returns JSON `[{"id":i64,"position":u64,"value":u64}, ...]`, or null on error. +/// +/// # Safety +/// +/// - `db_data` must be non-null and point to a path-encoded byte array of length `db_data_len`. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_get_notes_needing_pir_witness( + db_data: *const u8, + db_data_len: usize, + network_id: u32, +) -> *mut ffi::BoxedSlice { + let res = catch_panic(|| { + let network = parse_network(network_id)?; + let db_data = unsafe { wallet_db(db_data, db_data_len, network)? }; + + let notes = db_data + .get_notes_needing_pir_witness() + .map_err(|e| anyhow!("failed to query notes needing PIR witness: {e}"))?; + + let out: Vec = notes + .into_iter() + .map(|n| PirNotePosition { + id: n.id, + position: n.position, + value: n.value, + }) + .collect(); + + let json = serde_json::to_vec(&out)?; + Ok(ffi::BoxedSlice::some(json)) + }); + unwrap_exc_or_null(res) +} + +/// Returns Orchard notes referenced by a proposal that can be refreshed via witness PIR. +/// +/// Returns JSON `[{"id":i64,"position":u64,"value":u64}, ...]`, or null on error. +/// +/// # Safety +/// +/// - `db_data` must be non-null and point to a path-encoded byte array of length `db_data_len`. +/// - `proposal_ptr` must be non-null and point to a valid encoded `Proposal` protobuf. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_get_pir_witness_notes_for_proposal( + db_data: *const u8, + db_data_len: usize, + proposal_ptr: *const u8, + proposal_len: usize, + network_id: u32, +) -> *mut ffi::BoxedSlice { + let res = catch_panic(|| { + let network = parse_network(network_id)?; + let db_data = unsafe { wallet_db(db_data, db_data_len, network)? }; + + let proposal = + Proposal::decode(unsafe { slice::from_raw_parts(proposal_ptr, proposal_len) }) + .map_err(|e| anyhow!("Invalid proposal: {}", e))? + .try_into_standard_proposal(&db_data)?; + + let out: Vec = db_data + .get_pir_witness_notes_for_proposal(&proposal) + .into_iter() + .map(|note| PirNotePosition { + id: note.id, + position: note.position, + value: note.value, + }) + .collect(); + + let json = serde_json::to_vec(&out)?; + Ok(ffi::BoxedSlice::some(json)) + }); + unwrap_exc_or_null(res) +} + +#[derive(Debug, serde::Deserialize)] +struct WitnessInput { + note_id: i64, + siblings: Vec, + anchor_height: u64, + anchor_root: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DecodedWitnessInput { + note_id: i64, + siblings: [[u8; 32]; 32], + anchor_height: u64, + anchor_root: [u8; 32], +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WitnessValidationSummary { + computed_root: [u8; 32], +} + +/// Deserializes and validates hex-encoded PIR witness JSON into typed inputs. +/// +/// Each sibling and anchor root is decoded from hex and length-checked (32 bytes). +fn parse_pir_witness_inputs(json_bytes: &[u8]) -> anyhow::Result> { + let inputs: Vec = serde_json::from_slice(json_bytes) + .map_err(|e| anyhow!("failed to parse witnesses JSON: {e}"))?; + + inputs + .into_iter() + .map(|input| { + if input.siblings.len() != 32 { + return Err(anyhow!( + "witness for note {} has {} siblings, expected 32", + input.note_id, + input.siblings.len() + )); + } + + let mut siblings = [[0u8; 32]; 32]; + for (i, hex_str) in input.siblings.iter().enumerate() { + let bytes = hex::decode(hex_str).map_err(|e| { + anyhow!("invalid hex in sibling {i} for note {}: {e}", input.note_id) + })?; + if bytes.len() != 32 { + return Err(anyhow!( + "sibling {i} for note {} is {} bytes, expected 32", + input.note_id, + bytes.len() + )); + } + siblings[i].copy_from_slice(&bytes); + } + + let anchor_root_bytes = hex::decode(&input.anchor_root).map_err(|e| { + anyhow!("invalid hex in anchor_root for note {}: {e}", input.note_id) + })?; + let anchor_root: [u8; 32] = anchor_root_bytes + .try_into() + .map_err(|_| anyhow!("anchor_root for note {} is not 32 bytes", input.note_id))?; + + Ok(DecodedWitnessInput { + note_id: input.note_id, + siblings, + anchor_height: input.anchor_height, + anchor_root, + }) + }) + .collect() +} + +/// Validates each witness input against its note commitment, then inserts it. +/// +/// For each input, `validate` recomputes the Merkle root from the note and siblings; +/// if the root matches the provided anchor, `insert` persists the witness. A mismatch +/// is an immediate hard error (no partial writes). +fn apply_pir_witness_inputs( + inputs: Vec, + mut validate: Validate, + mut insert: Insert, +) -> anyhow::Result<()> +where + Validate: FnMut(&DecodedWitnessInput) -> anyhow::Result, + Insert: FnMut(&DecodedWitnessInput) -> anyhow::Result<()>, +{ + for input in inputs { + let validation = validate(&input)?; + let witness_root_matches_anchor = validation.computed_root == input.anchor_root; + + if !witness_root_matches_anchor { + tracing::warn!( + note_id = input.note_id, + anchor_height = input.anchor_height, + provided_anchor_root = %hex::encode(input.anchor_root), + computed_root = %hex::encode(validation.computed_root), + "witness FFI: rejecting PIR witness because computed root did not match provided anchor", + ); + return Err(anyhow!( + "PIR witness for note {} failed root validation before insert", + input.note_id + )); + } + + insert(&input)?; + } + + Ok(()) +} + +/// Inserts PIR-obtained witness data for notes into the wallet DB. +/// +/// `witnesses_json` is a JSON array of `WitnessInput` objects. +/// +/// Returns 0 on success, -1 on error. +/// +/// # Safety +/// +/// - `db_data` must be non-null and point to a path-encoded byte array of length `db_data_len`. +/// - `witnesses_json` must be non-null and point to a valid JSON byte array of length `witnesses_json_len`. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_insert_pir_witnesses( + db_data: *const u8, + db_data_len: usize, + network_id: u32, + witnesses_json: *const u8, + witnesses_json_len: usize, +) -> i32 { + let res = catch_panic(|| { + let network = parse_network(network_id)?; + let db_data = unsafe { wallet_db(db_data, db_data_len, network)? }; + let json_bytes = unsafe { std::slice::from_raw_parts(witnesses_json, witnesses_json_len) }; + let inputs = parse_pir_witness_inputs(json_bytes)?; + apply_pir_witness_inputs( + inputs, + |input| { + let validation = db_data + .validate_pir_orchard_witness( + input.note_id, + &input.siblings, + input.anchor_height, + &input.anchor_root, + ) + .map_err(|e| { + anyhow!( + "failed to validate PIR witness for note {} before insert: {e}", + input.note_id + ) + })?; + + Ok(WitnessValidationSummary { + computed_root: validation.computed_root, + }) + }, + |input| { + db_data + .insert_pir_witness( + input.note_id, + &input.siblings, + input.anchor_height, + &input.anchor_root, + ) + .map_err(|e| { + anyhow!( + "failed to insert PIR witness for note {}: {e}", + input.note_id + ) + }) + }, + )?; + Ok(0i32) + }); + crate::unwrap_exc_or(res, -1) +} + pub(crate) fn parse_optional_height(value: i64) -> anyhow::Result> { Ok(match value { -1 => None, _ => Some(BlockHeight::try_from(value)?), }) } + +#[cfg(test)] +mod tests { + use super::*; + use std::cell::Cell; + + fn witness_json(note_id: i64) -> Vec { + let json = serde_json::json!([{ + "note_id": note_id, + "siblings": vec!["00".repeat(32); 32], + "anchor_height": 100u64, + "anchor_root": "11".repeat(32) + }]); + serde_json::to_vec(&json).expect("serialize witness json") + } + + fn validation_summary(matches_anchor: bool) -> WitnessValidationSummary { + WitnessValidationSummary { + computed_root: if matches_anchor { [0x11; 32] } else { [6u8; 32] }, + } + } + + #[test] + fn parse_pir_witness_inputs_decodes_hex_payload() { + let inputs = parse_pir_witness_inputs(&witness_json(7)).expect("parse witness json"); + assert_eq!(inputs.len(), 1); + assert_eq!(inputs[0].note_id, 7); + assert_eq!(inputs[0].anchor_height, 100); + assert_eq!(inputs[0].siblings.len(), 32); + assert_eq!(inputs[0].anchor_root, [0x11; 32]); + } + + #[test] + fn apply_pir_witness_inputs_rejects_invalid_witness_before_insert() { + let inputs = parse_pir_witness_inputs(&witness_json(9)).expect("parse witness json"); + let insert_calls = Cell::new(0usize); + + let err = apply_pir_witness_inputs( + inputs, + |_input| Ok(validation_summary(false)), + |_input| { + insert_calls.set(insert_calls.get() + 1); + Ok(()) + }, + ) + .expect_err("invalid witness should be rejected before insert"); + + assert!( + err.to_string() + .contains("PIR witness for note 9 failed root validation before insert"), + "unexpected error: {err}" + ); + assert_eq!( + insert_calls.get(), + 0, + "insert should not be attempted after validation failure" + ); + } + + #[test] + fn apply_pir_witness_inputs_inserts_valid_witnesses() { + let inputs = parse_pir_witness_inputs(&witness_json(11)).expect("parse witness json"); + let validate_calls = Cell::new(0usize); + let insert_calls = Cell::new(0usize); + + apply_pir_witness_inputs( + inputs, + |_input| { + validate_calls.set(validate_calls.get() + 1); + Ok(validation_summary(true)) + }, + |_input| { + insert_calls.set(insert_calls.get() + 1); + Ok(()) + }, + ) + .expect("valid witness should be inserted"); + + assert_eq!(validate_calls.get(), 1); + assert_eq!(insert_calls.get(), 1); + } +} diff --git a/rust/src/spendability.rs b/rust/src/spendability.rs new file mode 100644 index 000000000..d21a178e0 --- /dev/null +++ b/rust/src/spendability.rs @@ -0,0 +1,266 @@ +//! C FFI for spendability & witness PIR — network-only calls. +//! +//! DB read/write operations are handled by the `zcashlc_*` functions +//! in `lib.rs` that go through `wallet_db()` and share the `@DBActor` +//! connection. + +use std::panic::AssertUnwindSafe; + +use anyhow::anyhow; +use ffi_helpers::panic::catch_panic; +use serde::{Deserialize, Serialize}; + +use spend_types::SpendMetadata; + +use crate::unwrap_exc_or_null; + +pub(crate) unsafe fn str_from_ptr(ptr: *const u8, len: usize) -> anyhow::Result { + let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; + Ok(std::str::from_utf8(bytes)?.to_string()) +} + +pub(crate) fn json_to_boxed_slice( + value: &T, +) -> anyhow::Result<*mut crate::ffi::BoxedSlice> { + let json = serde_json::to_vec(value)?; + Ok(crate::ffi::BoxedSlice::some(json)) +} + +#[derive(Serialize)] +struct NullifierCheckResult { + earliest_height: u64, + latest_height: u64, + /// Parallel to the input nullifiers: Some(meta) = spent, None = not spent. + spent: Vec>, +} + +/// Checks nullifiers against the PIR server. No database access. +/// +/// `nullifiers_json` is a JSON array of byte arrays (each 32 elements), +/// e.g. `[[0,1,...,31],[0,1,...,31]]`. +/// +/// Returns JSON `NullifierCheckResult`, or null on error. +/// +/// # Safety +/// +/// Pointer/length pairs must be valid UTF-8 slices. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_check_nullifiers_pir( + pir_server_url: *const u8, + pir_server_url_len: usize, + nullifiers_json: *const u8, + nullifiers_json_len: usize, + progress_callback: Option, + progress_context: *mut std::ffi::c_void, +) -> *mut crate::ffi::BoxedSlice { + let progress_context = AssertUnwindSafe(progress_context); + let res = catch_panic(|| { + let url = unsafe { str_from_ptr(pir_server_url, pir_server_url_len) }?; + let nf_bytes = unsafe { std::slice::from_raw_parts(nullifiers_json, nullifiers_json_len) }; + + let nf_vecs: Vec> = serde_json::from_slice(nf_bytes) + .map_err(|e| anyhow!("failed to parse nullifiers JSON: {e}"))?; + + let nullifiers: Vec<[u8; 32]> = nf_vecs + .into_iter() + .map(|v| { + v.try_into() + .map_err(|_| anyhow!("nullifier must be exactly 32 bytes")) + }) + .collect::>>()?; + + let client = spend_client::SpendClientBlocking::connect(&url) + .map_err(|e| anyhow!("PIR connect failed: {e}"))?; + + let spent = client + .check_nullifiers(&nullifiers, |progress| { + if let Some(cb) = progress_callback { + unsafe { cb(progress, *progress_context) }; + } + }) + .map_err(|e| anyhow!("PIR check failed: {e}"))?; + + let metadata = client.metadata(); + let result = NullifierCheckResult { + earliest_height: metadata.earliest_height, + latest_height: metadata.latest_height, + spent, + }; + + json_to_boxed_slice(&result) + }); + unwrap_exc_or_null(res) +} + +// ── Witness PIR ────────────────────────────────────────────────────── + +#[derive(Deserialize)] +struct PositionInput { + note_id: i64, + position: u64, +} + +#[derive(Serialize)] +struct WitnessEntry { + note_id: i64, + position: u64, + /// 32 siblings, each 32 bytes, hex-encoded. + siblings: Vec, + anchor_height: u64, + anchor_root: String, +} + +#[derive(Serialize)] +struct WitnessCheckResult { + witnesses: Vec, +} + +/// Fetches note commitment witnesses from the PIR server. No database access. +/// +/// `positions_json` is a JSON array of `{"note_id": i64, "position": u64}`. +/// +/// Returns JSON `WitnessCheckResult`, or null on error. +/// +/// # Safety +/// +/// Pointer/length pairs must be valid UTF-8 slices. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn zcashlc_fetch_pir_witnesses( + pir_server_url: *const u8, + pir_server_url_len: usize, + positions_json: *const u8, + positions_json_len: usize, + progress_callback: Option, + progress_context: *mut std::ffi::c_void, +) -> *mut crate::ffi::BoxedSlice { + let progress_context = AssertUnwindSafe(progress_context); + let res = catch_panic(|| { + let t0 = std::time::Instant::now(); + let url = unsafe { str_from_ptr(pir_server_url, pir_server_url_len) }?; + let pos_bytes = unsafe { std::slice::from_raw_parts(positions_json, positions_json_len) }; + + let inputs: Vec = serde_json::from_slice(pos_bytes) + .map_err(|e| anyhow!("failed to parse positions JSON: {e}"))?; + + if inputs.is_empty() { + return json_to_boxed_slice(&WitnessCheckResult { witnesses: vec![] }); + } + + tracing::info!(num_notes = inputs.len(), url = %url, "witness FFI: starting"); + + let t1 = std::time::Instant::now(); + let client = witness_client::WitnessClientBlocking::connect(&url) + .map_err(|e| anyhow!("witness PIR connect failed: {e}"))?; + tracing::info!( + elapsed_ms = t1.elapsed().as_millis(), + "witness FFI: connected" + ); + + let positions: Vec = inputs.iter().map(|i| i.position).collect(); + + let t2 = std::time::Instant::now(); + let pir_witnesses = client + .get_witnesses(&positions, |frac| { + if let Some(cb) = progress_callback { + unsafe { cb(frac, *progress_context) }; + } + }) + .map_err(|e| anyhow!("PIR witness batch query failed: {e}"))?; + tracing::info!( + elapsed_ms = t2.elapsed().as_millis(), + count = pir_witnesses.len(), + "witness FFI: queries complete", + ); + + let witnesses: Vec = inputs + .iter() + .zip(pir_witnesses.iter()) + .map(|(input, w)| WitnessEntry { + note_id: input.note_id, + position: input.position, + siblings: w.siblings.iter().map(hex::encode).collect(), + anchor_height: w.anchor_height, + anchor_root: hex::encode(w.anchor_root), + }) + .collect(); + + tracing::info!( + total_ms = t0.elapsed().as_millis(), + num_witnesses = witnesses.len(), + "witness FFI: done", + ); + + json_to_boxed_slice(&WitnessCheckResult { witnesses }) + }); + unwrap_exc_or_null(res) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_nullifier_check_result() { + let result = NullifierCheckResult { + earliest_height: 100, + latest_height: 200, + spent: vec![ + Some(SpendMetadata { + spend_height: 150, + first_output_position: 5000, + action_count: 3, + }), + None, + Some(SpendMetadata { + spend_height: 180, + first_output_position: 8000, + action_count: 1, + }), + ], + }; + let json: serde_json::Value = serde_json::to_value(&result).unwrap(); + assert_eq!(json["earliest_height"], 100); + assert_eq!(json["latest_height"], 200); + assert_eq!(json["spent"][0]["spend_height"], 150); + assert_eq!(json["spent"][0]["first_output_position"], 5000); + assert_eq!(json["spent"][0]["action_count"], 3); + assert!(json["spent"][1].is_null()); + assert_eq!(json["spent"][2]["spend_height"], 180); + } + + #[test] + fn serialize_empty_nullifier_check_result() { + let result = NullifierCheckResult { + earliest_height: 0, + latest_height: 0, + spent: vec![], + }; + let json: serde_json::Value = serde_json::to_value(&result).unwrap(); + assert_eq!(json["spent"], serde_json::json!([])); + } + + #[test] + fn serialize_witness_check_result() { + let result = WitnessCheckResult { + witnesses: vec![WitnessEntry { + note_id: 42, + position: 1000, + siblings: vec!["aa".repeat(32)], + anchor_height: 3200000, + anchor_root: "bb".repeat(32), + }], + }; + let json: serde_json::Value = serde_json::to_value(&result).unwrap(); + assert_eq!(json["witnesses"][0]["note_id"], 42); + assert_eq!(json["witnesses"][0]["position"], 1000); + assert_eq!(json["witnesses"][0]["anchor_height"], 3200000); + assert_eq!(json["witnesses"][0]["anchor_root"], "bb".repeat(32)); + } + + #[test] + fn serialize_empty_witness_check_result() { + let result = WitnessCheckResult { witnesses: vec![] }; + let json: serde_json::Value = serde_json::to_value(&result).unwrap(); + assert_eq!(json["witnesses"], serde_json::json!([])); + } +}