Skip to content

Commit

Permalink
encapsulate sparse API (#352)
Browse files Browse the repository at this point in the history
Co-authored-by: Hanting Zhang <[email protected]>
  • Loading branch information
winston-h-zhang and Hanting Zhang authored Mar 12, 2024
1 parent fe39652 commit ef677bf
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/bellpepper/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ fn add_constraint<S: PrimeField>(
) {
let (A, B, C, nn) = X;
let n = **nn;
assert_eq!(n + 1, A.indptr.len(), "A: invalid shape");
assert_eq!(n + 1, B.indptr.len(), "B: invalid shape");
assert_eq!(n + 1, C.indptr.len(), "C: invalid shape");
assert_eq!(n, A.num_rows(), "A: invalid shape");
assert_eq!(n, B.num_rows(), "B: invalid shape");
assert_eq!(n, C.num_rows(), "C: invalid shape");

let add_constraint_component = |index: Index, coeff: &S, M: &mut SparseMatrix<S>| {
// we add constraints to the matrix only if the associated coefficient is non-zero
Expand Down
38 changes: 38 additions & 0 deletions src/r1cs/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use ff::PrimeField;
use itertools::Itertools as _;
use rand_core::{CryptoRng, RngCore};
use rayon::prelude::*;
use ref_cast::RefCast;
use serde::{Deserialize, Serialize};

/// CSR format sparse matrix, We follow the names used by scipy.
Expand All @@ -31,6 +32,11 @@ pub struct SparseMatrix<F: PrimeField> {
pub cols: usize,
}

/// Wrapper type for encode rows of [`SparseMatrix`]
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct RowData([usize; 2]);

/// [`SparseMatrix`]s are often large, and this helps with cloning bottlenecks
impl<F: PrimeField> Clone for SparseMatrix<F> {
fn clone(&self) -> Self {
Expand Down Expand Up @@ -111,6 +117,30 @@ impl<F: PrimeField> SparseMatrix<F> {
Self::new(&matrix, rows, cols)
}

/// Returns an iterator into the rows
pub fn iter_rows(&self) -> impl Iterator<Item = &RowData> {
self
.indptr
.windows(2)
.map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
}

/// Returns a parallel iterator into the rows
pub fn par_iter_rows(&self) -> impl IndexedParallelIterator<Item = &RowData> {
self
.indptr
.par_windows(2)
.map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
}

/// Retrieves the data for row slice [i..j] from `row`.
/// [`RowData`] **must** be created from unmodified `self` previously to guarentee safety.
pub fn get_row(&self, row: &RowData) -> impl Iterator<Item = (&F, &usize)> {
self.data[row.0[0]..row.0[1]]
.iter()
.zip_eq(&self.indices[row.0[0]..row.0[1]])
}

/// Retrieves the data for row slice [i..j] from `ptrs`.
/// We assume that `ptrs` is indexed from `indptrs` and do not check if the
/// returned slice is actually a valid row.
Expand Down Expand Up @@ -226,6 +256,14 @@ impl<F: PrimeField> SparseMatrix<F> {
nnz: *self.indptr.last().unwrap(),
}
}

pub fn num_rows(&self) -> usize {
self.indptr.len() - 1
}

pub fn num_cols(&self) -> usize {
self.cols
}
}

/// Iterator for sparse matrix
Expand Down
8 changes: 3 additions & 5 deletions src/spartan/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,11 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
r_y: &[E::Scalar]|
-> Vec<E::Scalar> {
let evaluate_with_table =
// TODO(@winston-h-zhang): review
|M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
M.indptr
.par_windows(2)
M.par_iter_rows()
.enumerate()
.map(|(row_idx, ptrs)| {
M.get_row_unchecked(ptrs.try_into().unwrap())
.map(|(row_idx, row)| {
M.get_row(row)
.map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
.sum::<E::Scalar>()
})
Expand Down
5 changes: 3 additions & 2 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ fn compute_eval_table_sparse<E: Engine>(
assert_eq!(rx.len(), S.num_cons);

let inner = |M: &SparseMatrix<E::Scalar>, M_evals: &mut Vec<E::Scalar>| {
for (row_idx, ptrs) in M.indptr.windows(2).enumerate() {
for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) {
for (row_idx, row) in M.iter_rows().enumerate() {
for (val, col_idx) in M.get_row(row) {
// TODO(@winston-h-zhang): Parallelize? Will need more complicated locking
M_evals[*col_idx] += rx[row_idx] * val;
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/spartan/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,10 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for Relax
-> Vec<E::Scalar> {
let evaluate_with_table =
|M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
M.indptr
.par_windows(2)
M.par_iter_rows()
.enumerate()
.map(|(row_idx, ptrs)| {
M.get_row_unchecked(ptrs.try_into().unwrap())
.map(|(row_idx, row)| {
M.get_row(row)
.map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
.sum::<E::Scalar>()
})
Expand Down

1 comment on commit ef677bf

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Arecibo GPU benchmarks.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/arecibo/actions/runs/8253538293

Benchmark Results

RecursiveSNARK-NIVC-2

ref=fe39652 ref=ef677bf
Prove-NumCons-6540 46.26 ms (✅ 1.00x) 44.98 ms (✅ 1.03x faster)
Verify-NumCons-6540 36.63 ms (✅ 1.00x) 34.34 ms (✅ 1.07x faster)
Prove-NumCons-1028888 322.96 ms (✅ 1.00x) 319.17 ms (✅ 1.01x faster)
Verify-NumCons-1028888 252.67 ms (✅ 1.00x) 250.30 ms (✅ 1.01x faster)

CompressedSNARK-NIVC-Commitments-2

ref=fe39652 ref=ef677bf
Prove-NumCons-6540 10.50 s (✅ 1.00x) 10.77 s (✅ 1.03x slower)
Verify-NumCons-6540 51.37 ms (✅ 1.00x) 52.39 ms (✅ 1.02x slower)
Prove-NumCons-1028888 53.69 s (✅ 1.00x) 52.88 s (✅ 1.02x faster)
Verify-NumCons-1028888 52.25 ms (✅ 1.00x) 52.19 ms (✅ 1.00x faster)

Made with criterion-table

Please sign in to comment.