Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Nargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ authors = [""]
compiler_version = ">=0.36.0"

[dependencies]
sort = { tag = "v0.2.3", git = "https://github.com/noir-lang/noir_sort" }
sort = { tag = "v0.3.0", git = "https://github.com/noir-lang/noir_sort" }
75 changes: 35 additions & 40 deletions src/lib.nr
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
mod mut_sparse_array;
use dep::sort::sort_advanced;

unconstrained fn __sort_field_as_u32(lhs: Field, rhs: Field) -> bool {
unconstrained fn __sort(lhs: u32, rhs: u32) -> bool {
// lhs.lt(rhs)
lhs as u32 < rhs as u32
lhs < rhs
}

fn assert_sorted(lhs: Field, rhs: Field) {
let result = (rhs - lhs - 1);
result.assert_max_bit_size::<32>();
fn assert_sorted(lhs: u32, rhs: u32) {
assert(lhs < rhs);
}
Comment thread
TomAFrench marked this conversation as resolved.

/**
Expand All @@ -24,10 +23,10 @@ fn assert_sorted(lhs: Field, rhs: Field) {
**/
struct MutSparseArrayBase<let N: u32, T, ComparisonFuncs> {
values: [T; N + 3],
keys: [Field; N + 2],
linked_keys: [Field; N + 2],
tail_ptr: Field,
maximum: Field,
keys: [u32; N + 2],
linked_keys: [u32; N + 2],
tail_ptr: u32,
maximum: u32,
}

struct U32RangeTraits {}
Expand All @@ -47,9 +46,9 @@ pub struct MutSparseArray<let N: u32, T> {
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
**/
pub struct SparseArray<let N: u32, T> {
keys: [Field; N + 2],
keys: [u32; N + 2],
values: [T; N + 3],
maximum: Field, // can be up to 2^32
maximum: u32, // can be up to 2^32 - 1
}
impl<let N: u32, T> SparseArray<N, T>
where
Expand All @@ -59,15 +58,16 @@ where
/**
* @brief construct a SparseArray
**/
pub(crate) fn create(_keys: [Field; N], _values: [T; N], size: Field) -> Self {
pub(crate) fn create(_keys: [u32; N], _values: [T; N], size: u32) -> Self {
assert(size >= 1);
let _maximum = size - 1;
let mut r: Self =
SparseArray { keys: [0; N + 2], values: [T::default(); N + 3], maximum: _maximum };

// for any valid index, we want to ensure the following is satified:
// self.keys[X] <= index <= self.keys[X+1]
// this requires us to sort hte keys, and insert a startpoint and endpoint
let sorted_keys = sort_advanced(_keys, __sort_field_as_u32, assert_sorted);
let sorted_keys = sort_advanced(_keys, __sort, assert_sorted);

// insert start and endpoints
r.keys[0] = 0;
Expand Down Expand Up @@ -103,45 +103,41 @@ where
// because `self.keys` is sorted, we can simply validate that
// sorted_keys.sorted[0] < 2^32
// sorted_keys.sorted[N-1] < maximum
sorted_keys.sorted[0].assert_max_bit_size::<32>();
_maximum.assert_max_bit_size::<32>();
(_maximum - sorted_keys.sorted[N - 1]).assert_max_bit_size::<32>();
assert(_maximum >= sorted_keys.sorted[N - 1]);
r
}

/**
* @brief determine whether `target` is present in `self.keys`
* @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
**/
unconstrained fn search_for_key(self, target: Field) -> (Field, Field) {
unconstrained fn search_for_key(self, target: u32) -> (bool, u32) {
let mut found = false;
let mut found_index = 0;
let mut found_index: u32 = 0;
let mut previous_less_than_or_equal_to_target = false;
for i in 0..N + 2 {
// if target = 0xffffffff we need to be able to add 1 here, so use u64
let current_less_than_or_equal_to_target = self.keys[i] as u64 <= target as u64;
if (self.keys[i] == target) {
found = true;
found_index = i as Field;
found_index = i;
break;
}
if (previous_less_than_or_equal_to_target & !current_less_than_or_equal_to_target) {
found_index = i as Field - 1;
found_index = i - 1;
break;
}
previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target;
}
(found as Field, found_index)
(found, found_index)
}

/**
* @brief return element `idx` from the sparse array
* @details cost is 14.5 gates per lookup
**/
fn get(self, idx: Field) -> T {
fn get(self, idx: u32) -> T {
let (found, found_index) = unsafe { self.search_for_key(idx) };
// bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
assert(found * found == found);

// OK! So we have the following cases to check
// 1. if `found` then `self.keys[found_index] == idx`
Expand All @@ -152,15 +148,13 @@ where
// combine the two into the following single statement:
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
let lhs = self.keys[found_index];
let rhs = self.keys[found_index + 1 - found];
let lhs_condition = idx - lhs - 1 + found;
let rhs_condition = rhs - 1 + found - idx;
lhs_condition.assert_max_bit_size::<32>();
rhs_condition.assert_max_bit_size::<32>();
let rhs = self.keys[found_index + 1 - found as u32];
assert(lhs + 1 - found as u32 <= idx);
assert(idx <= rhs + found as u32 - 1);

// self.keys[i] maps to self.values[i+1]
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
let value_index = (found_index + 1) * found;
let value_index = (found_index + 1) * found as u32;
self.values[value_index]
}
}
Expand All @@ -179,7 +173,7 @@ mod test {

for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == 0);
assert(example.get(i) == 0);
}
}
}
Expand All @@ -188,34 +182,35 @@ mod test {
fn test_sparse_lookup_boundary_cases() {
// what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
let example = SparseArray::create(
[0, 99999, 7, 0xffffffff],
[0, 99999, 7, 0xfffffffe],
[123, 101112, 789, 456],
0x100000000,
0xffffffff,
);

assert(example.get(0) == 123);
assert(example.get(99999) == 101112);
assert(example.get(7) == 789);
assert(example.get(0xffffffff) == 456);
assert(example.get(0xfffffffe) == 0);
assert(example.get(0xfffffffe) == 456);
assert(example.get(0xfffffffd) == 0);
}

#[test(should_fail_with = "call to assert_max_bit_size")]
#[test(should_fail)]
fn test_sparse_lookup_overflow() {
let example = SparseArray::create([1, 5, 7, 99999], [123, 456, 789, 101112], 100000);

assert(example.get(100000) == 0);
}

/**
#[test(should_fail_with = "call to assert_max_bit_size")]
fn test_sparse_lookup_boundary_case_overflow() {
let example =
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0x100000000);

assert(example.get(0x100000000) == 0);
}

#[test(should_fail_with = "call to assert_max_bit_size")]
**/
#[test(should_fail)]
fn test_sparse_lookup_key_exceeds_maximum() {
let example =
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0xffffffff);
Expand All @@ -236,7 +231,7 @@ mod test {

for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == 0);
assert(example.get(i) == 0);
}
}
}
Expand Down Expand Up @@ -272,7 +267,7 @@ mod test {
assert(example.get(99) == values[1]);
for i in 0..100 {
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
assert(example.get(i as Field) == F::default());
assert(example.get(i) == F::default());
}
}
}
Expand Down
Loading
Loading