Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions librustvoting/src/storage/migrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rusqlite::Connection;

use crate::VotingError;

const CURRENT_VERSION: u32 = 3;
const CURRENT_VERSION: u32 = 4;

pub fn migrate(conn: &Connection) -> Result<(), VotingError> {
let version: u32 = conn
Expand Down Expand Up @@ -75,6 +75,30 @@ pub fn migrate(conn: &Connection) -> Result<(), VotingError> {
})?;
}

if version < 4 {
// v4: add wallet_id column for per-wallet state isolation.
// Drop everything and recreate from 001_init.sql.
conn.execute_batch(
"DROP TABLE IF EXISTS votes;
DROP TABLE IF EXISTS witnesses;
DROP TABLE IF EXISTS proofs;
DROP TABLE IF EXISTS bundles;
DROP TABLE IF EXISTS cached_tree_state;
DROP TABLE IF EXISTS rounds;"
)
.map_err(|e| VotingError::Internal {
message: format!("migration to version 4 failed (drop): {}", e),
})?;
conn.execute_batch(include_str!("migrations/001_init.sql"))
.map_err(|e| VotingError::Internal {
message: format!("migration to version 4 failed (create): {}", e),
})?;
conn.pragma_update(None, "user_version", 4)
.map_err(|e| VotingError::Internal {
message: format!("failed to update database version: {}", e),
})?;
}

let final_version: u32 = conn
.pragma_query_value(None, "user_version", |r| r.get(0))
.map_err(|e| VotingError::Internal {
Expand Down Expand Up @@ -148,13 +172,13 @@ mod tests {

// Insert a round first
conn.execute(
"INSERT INTO rounds (round_id, snapshot_height, ea_pk, nc_root, nullifier_imt_root, phase, created_at) VALUES ('test', 1, X'00', X'00', X'00', 0, 0)",
"INSERT INTO rounds (round_id, wallet_id, snapshot_height, ea_pk, nc_root, nullifier_imt_root, phase, created_at) VALUES ('test', 'w1', 1, X'00', X'00', X'00', 0, 0)",
[],
).unwrap();

// Insert a bundle row using all nullable BLOB columns.
conn.execute(
"INSERT INTO bundles (round_id, bundle_index, van_comm_rand, dummy_nullifiers, rho_signed, padded_note_data, nf_signed, cmx_new, alpha, rseed_signed, rseed_output) VALUES ('test', 0, X'AA', X'BB', X'CC', X'DD', X'EE', X'FF', X'11', X'22', X'33')",
"INSERT INTO bundles (round_id, wallet_id, bundle_index, van_comm_rand, dummy_nullifiers, rho_signed, padded_note_data, nf_signed, cmx_new, alpha, rseed_signed, rseed_output) VALUES ('test', 'w1', 0, X'AA', X'BB', X'CC', X'DD', X'EE', X'FF', X'11', X'22', X'33')",
[],
).unwrap();

Expand Down
34 changes: 22 additions & 12 deletions librustvoting/src/storage/migrations/001_init.sql
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
CREATE TABLE rounds (
round_id TEXT PRIMARY KEY,
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
snapshot_height INTEGER NOT NULL,
ea_pk BLOB NOT NULL,
nc_root BLOB NOT NULL,
nullifier_imt_root BLOB NOT NULL,
session_json TEXT,
phase INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
created_at INTEGER NOT NULL,
PRIMARY KEY (round_id, wallet_id)
);

CREATE TABLE bundles (
round_id TEXT NOT NULL REFERENCES rounds(round_id) ON DELETE CASCADE,
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
bundle_index INTEGER NOT NULL,
note_positions_blob BLOB,
van_comm_rand BLOB,
Expand All @@ -30,47 +33,54 @@ CREATE TABLE bundles (
gov_nullifiers_blob BLOB,
padded_note_secrets BLOB,
pczt_sighash BLOB,
PRIMARY KEY (round_id, bundle_index)
PRIMARY KEY (round_id, wallet_id, bundle_index),
FOREIGN KEY (round_id, wallet_id) REFERENCES rounds(round_id, wallet_id) ON DELETE CASCADE
);

CREATE TABLE cached_tree_state (
round_id TEXT PRIMARY KEY REFERENCES rounds(round_id) ON DELETE CASCADE,
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
snapshot_height INTEGER NOT NULL,
tree_state BLOB NOT NULL
tree_state BLOB NOT NULL,
PRIMARY KEY (round_id, wallet_id),
FOREIGN KEY (round_id, wallet_id) REFERENCES rounds(round_id, wallet_id) ON DELETE CASCADE
);

CREATE TABLE proofs (
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
bundle_index INTEGER NOT NULL,
witness BLOB,
proof BLOB,
success INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL,
PRIMARY KEY (round_id, bundle_index),
FOREIGN KEY (round_id, bundle_index) REFERENCES bundles(round_id, bundle_index) ON DELETE CASCADE
PRIMARY KEY (round_id, wallet_id, bundle_index),
FOREIGN KEY (round_id, wallet_id, bundle_index) REFERENCES bundles(round_id, wallet_id, bundle_index) ON DELETE CASCADE
);

CREATE TABLE witnesses (
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
bundle_index INTEGER NOT NULL,
note_position INTEGER NOT NULL,
note_commitment BLOB NOT NULL,
root BLOB NOT NULL,
auth_path BLOB NOT NULL,
created_at INTEGER NOT NULL,
PRIMARY KEY (round_id, bundle_index, note_position),
FOREIGN KEY (round_id, bundle_index) REFERENCES bundles(round_id, bundle_index) ON DELETE CASCADE
PRIMARY KEY (round_id, wallet_id, bundle_index, note_position),
FOREIGN KEY (round_id, wallet_id, bundle_index) REFERENCES bundles(round_id, wallet_id, bundle_index) ON DELETE CASCADE
);

CREATE TABLE votes (
id INTEGER PRIMARY KEY,
round_id TEXT NOT NULL,
wallet_id TEXT NOT NULL DEFAULT '',
bundle_index INTEGER NOT NULL,
proposal_id INTEGER NOT NULL,
choice INTEGER NOT NULL,
commitment BLOB,
submitted INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL,
UNIQUE(round_id, bundle_index, proposal_id),
FOREIGN KEY (round_id, bundle_index) REFERENCES bundles(round_id, bundle_index) ON DELETE CASCADE
UNIQUE(round_id, wallet_id, bundle_index, proposal_id),
FOREIGN KEY (round_id, wallet_id, bundle_index) REFERENCES bundles(round_id, wallet_id, bundle_index) ON DELETE CASCADE
);
109 changes: 76 additions & 33 deletions librustvoting/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,23 @@ pub struct VoteRecord {
#[derive(Clone, Debug)]
pub struct RoundSummary {
pub round_id: String,
pub wallet_id: String,
pub phase: RoundPhase,
pub snapshot_height: u64,
pub created_at: u64,
}

/// Database handle for voting state. Wraps a SQLite connection.
/// Database handle for voting state. Wraps a SQLite connection and a
/// wallet identifier that scopes all round data to a single wallet.
pub struct VotingDb {
conn: Mutex<Connection>,
wallet_id: Mutex<String>,
}

impl VotingDb {
/// Open (or create) the voting database at the given path.
/// Runs migrations automatically.
/// Call `set_wallet_id` before performing any round operations.
pub fn open(path: &str) -> Result<Self, VotingError> {
let conn = if path == ":memory:" {
Connection::open_in_memory()
Expand All @@ -88,9 +92,22 @@ impl VotingDb {

Ok(Self {
conn: Mutex::new(conn),
wallet_id: Mutex::new(String::new()),
})
}

/// Set the wallet identifier used to scope all subsequent operations.
pub fn set_wallet_id(&self, id: &str) {
*self.wallet_id.lock().expect("wallet_id mutex poisoned") = id.to_string();
}

/// Get the current wallet identifier. Panics if not set.
pub fn wallet_id(&self) -> String {
let id = self.wallet_id.lock().expect("wallet_id mutex poisoned").clone();
assert!(!id.is_empty(), "wallet_id must be set before performing voting operations");
id
}

/// Get a lock on the underlying connection for query execution.
pub fn conn(&self) -> std::sync::MutexGuard<'_, Connection> {
self.conn.lock().expect("database mutex poisoned")
Expand All @@ -102,6 +119,8 @@ mod tests {
use super::*;
use crate::types::VotingRoundParams;

const W: &str = "test-wallet";

fn test_db() -> VotingDb {
VotingDb::open(":memory:").unwrap()
}
Expand All @@ -123,7 +142,7 @@ mod tests {
let version: u32 = conn
.pragma_query_value(None, "user_version", |r| r.get(0))
.unwrap();
assert_eq!(version, 3);
assert_eq!(version, 4);
}

#[test]
Expand All @@ -132,94 +151,118 @@ mod tests {
let conn = db.conn();
let params = test_params();

queries::insert_round(&conn, &params, None).unwrap();
queries::insert_round(&conn, W, &params, None).unwrap();

let state = queries::get_round_state(&conn, "test-round-1").unwrap();
let state = queries::get_round_state(&conn, "test-round-1", W).unwrap();
assert_eq!(state.phase, RoundPhase::Initialized);
assert_eq!(state.snapshot_height, 1000);
assert!(!state.proof_generated);

let rounds = queries::list_rounds(&conn).unwrap();
let rounds = queries::list_rounds(&conn, W).unwrap();
assert_eq!(rounds.len(), 1);
assert_eq!(rounds[0].round_id, "test-round-1");

queries::clear_round(&conn, "test-round-1").unwrap();
let rounds = queries::list_rounds(&conn).unwrap();
queries::clear_round(&conn, "test-round-1", W).unwrap();
let rounds = queries::list_rounds(&conn, W).unwrap();
assert!(rounds.is_empty());
}

#[test]
fn test_tree_state_cache() {
let db = test_db();
let conn = db.conn();
queries::insert_round(&conn, &test_params(), None).unwrap();
queries::insert_round(&conn, W, &test_params(), None).unwrap();

let tree_state = vec![0xCC; 1024];
queries::store_tree_state(&conn, "test-round-1", 1000, &tree_state).unwrap();
queries::store_tree_state(&conn, "test-round-1", W, 1000, &tree_state).unwrap();

let loaded = queries::load_tree_state(&conn, "test-round-1").unwrap();
let loaded = queries::load_tree_state(&conn, "test-round-1", W).unwrap();
assert_eq!(loaded, tree_state);
}

#[test]
fn test_proof_storage() {
let db = test_db();
let conn = db.conn();
queries::insert_round(&conn, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", 0, &[]).unwrap();
queries::store_proof(&conn, "test-round-1", 0, &vec![0xAB; 256]).unwrap();
queries::insert_round(&conn, W, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", W, 0, &[]).unwrap();
queries::store_proof(&conn, "test-round-1", W, 0, &vec![0xAB; 256]).unwrap();

// proof_generated requires both proof AND VAN position
let state = queries::get_round_state(&conn, "test-round-1").unwrap();
let state = queries::get_round_state(&conn, "test-round-1", W).unwrap();
assert!(!state.proof_generated, "proof alone should not be enough");

queries::store_van_position(&conn, "test-round-1", 0, 42).unwrap();
let state = queries::get_round_state(&conn, "test-round-1").unwrap();
queries::store_van_position(&conn, "test-round-1", W, 0, 42).unwrap();
let state = queries::get_round_state(&conn, "test-round-1", W).unwrap();
assert!(state.proof_generated, "proof + VAN position should be enough");
}

#[test]
fn test_vote_storage() {
let db = test_db();
let conn = db.conn();
queries::insert_round(&conn, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", 0, &[]).unwrap();
queries::insert_round(&conn, W, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", W, 0, &[]).unwrap();

let commitment = vec![0xCC; 128];
queries::store_vote(&conn, "test-round-1", 0, 0, 0, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", 0, 1, 1, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", W, 0, 0, 0, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", W, 0, 1, 1, &commitment).unwrap();

queries::mark_vote_submitted(&conn, "test-round-1", 0, 0).unwrap();
queries::mark_vote_submitted(&conn, "test-round-1", W, 0, 0).unwrap();
}

#[test]
fn test_get_votes() {
let db = test_db();
let conn = db.conn();
queries::insert_round(&conn, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", 0, &[]).unwrap();
queries::insert_round(&conn, W, &test_params(), None).unwrap();
queries::insert_bundle(&conn, "test-round-1", W, 0, &[]).unwrap();

// No votes initially
let votes = queries::get_votes(&conn, "test-round-1").unwrap();
let votes = queries::get_votes(&conn, "test-round-1", W).unwrap();
assert!(votes.is_empty());

// Store two votes with different choices
let commitment = vec![0xCC; 128];
queries::store_vote(&conn, "test-round-1", 0, 0, 0, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", 0, 1, 2, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", W, 0, 0, 0, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", W, 0, 1, 2, &commitment).unwrap();

let votes = queries::get_votes(&conn, "test-round-1").unwrap();
let votes = queries::get_votes(&conn, "test-round-1", W).unwrap();
assert_eq!(votes.len(), 2);
assert_eq!(votes[0].proposal_id, 0);
assert_eq!(votes[0].choice, 0);
assert!(!votes[0].submitted);
assert_eq!(votes[1].proposal_id, 1);
assert_eq!(votes[1].choice, 2);

// Mark first vote submitted and verify
queries::mark_vote_submitted(&conn, "test-round-1", 0, 0).unwrap();
let votes = queries::get_votes(&conn, "test-round-1").unwrap();
queries::mark_vote_submitted(&conn, "test-round-1", W, 0, 0).unwrap();
let votes = queries::get_votes(&conn, "test-round-1", W).unwrap();
assert!(votes[0].submitted);
assert!(!votes[1].submitted);
}

#[test]
fn test_wallet_isolation() {
let db = test_db();
let conn = db.conn();
let params = test_params();

queries::insert_round(&conn, "wallet-a", &params, None).unwrap();
queries::insert_round(&conn, "wallet-b", &params, None).unwrap();

queries::insert_bundle(&conn, "test-round-1", "wallet-a", 0, &[]).unwrap();
queries::insert_bundle(&conn, "test-round-1", "wallet-b", 0, &[]).unwrap();

let commitment = vec![0xCC; 128];
queries::store_vote(&conn, "test-round-1", "wallet-a", 0, 0, 1, &commitment).unwrap();
queries::store_vote(&conn, "test-round-1", "wallet-b", 0, 0, 2, &commitment).unwrap();

let votes_a = queries::get_votes(&conn, "test-round-1", "wallet-a").unwrap();
let votes_b = queries::get_votes(&conn, "test-round-1", "wallet-b").unwrap();
assert_eq!(votes_a.len(), 1);
assert_eq!(votes_b.len(), 1);
assert_eq!(votes_a[0].choice, 1);
assert_eq!(votes_b[0].choice, 2);

queries::clear_round(&conn, "test-round-1", "wallet-a").unwrap();
let rounds_b = queries::list_rounds(&conn, "wallet-b").unwrap();
assert_eq!(rounds_b.len(), 1, "wallet-b round should survive wallet-a clear");
}
Comment on lines +242 to +267
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

}
Loading