-
Notifications
You must be signed in to change notification settings - Fork 99
/
batch_hasher.rs
93 lines (84 loc) · 2.84 KB
/
batch_hasher.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::fmt::{self, Debug};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use crate::error::{ClError, Error};
use crate::poseidon::SimplePoseidonBatchHasher;
#[cfg(any(feature = "cuda", feature = "opencl"))]
use crate::proteus::gpu::ClBatchHasher;
use crate::{Arity, BatchHasher, NeptuneField, Strength, DEFAULT_STRENGTH};
use ec_gpu_gen::rust_gpu_tools::Device;
use ff::PrimeField;
use generic_array::GenericArray;
/// Poseidon batch hasher that dispatches each call to either a CPU or a GPU backend.
///
/// CPU instances come from `Batcher::new_cpu` / `Batcher::with_strength_cpu`.
/// When the `cuda` or `opencl` feature is enabled, GPU instances come from
/// `Batcher::pick_gpu`, `Batcher::new`, or `Batcher::with_strength`.
// NOTE(review): `large_enum_variant` is silenced presumably because the two
// backend types differ a lot in size and boxing was deliberately avoided —
// confirm the rationale.
#[allow(clippy::large_enum_variant)]
pub enum Batcher<F: NeptuneField, A: Arity<F>> {
    /// Hashes on the CPU through `SimplePoseidonBatchHasher`.
    Cpu(SimplePoseidonBatchHasher<F, A>),
    /// Hashes on a GPU through `ClBatchHasher` (only in `cuda`/`opencl` builds).
    #[cfg(any(feature = "cuda", feature = "opencl"))]
    OpenCl(ClBatchHasher<F, A>),
}
impl<F, A> Batcher<F, A>
where
    F: NeptuneField,
    A: Arity<F>,
{
    /// Builds a CPU-backed batcher that uses [`DEFAULT_STRENGTH`].
    ///
    /// `max_batch_size` is forwarded unchanged to the CPU hasher.
    pub fn new_cpu(max_batch_size: usize) -> Self {
        Self::with_strength_cpu(DEFAULT_STRENGTH, max_batch_size)
    }

    /// Builds a CPU-backed batcher with an explicit Poseidon `strength`.
    pub fn with_strength_cpu(strength: Strength, max_batch_size: usize) -> Self {
        let hasher =
            SimplePoseidonBatchHasher::<F, A>::new_with_strength(strength, max_batch_size);
        Self::Cpu(hasher)
    }

    /// Builds a GPU-backed batcher on the first entry returned by `Device::all()`.
    ///
    /// # Errors
    ///
    /// Returns `Error::ClError(ClError::DeviceNotFound)` when `Device::all()`
    /// yields no device. Otherwise, returns whatever `Batcher::new` returns.
    #[cfg(any(feature = "cuda", feature = "opencl"))]
    pub fn pick_gpu(max_batch_size: usize) -> Result<Self, Error> {
        match Device::all().first() {
            // The list holds device references; `&device` copies one out.
            Some(&device) => Self::new(device, max_batch_size),
            None => Err(Error::ClError(ClError::DeviceNotFound)),
        }
    }

    /// Builds a GPU-backed batcher on `device` that uses [`DEFAULT_STRENGTH`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `Batcher::with_strength`.
    #[cfg(any(feature = "cuda", feature = "opencl"))]
    pub fn new(device: &Device, max_batch_size: usize) -> Result<Self, Error> {
        Self::with_strength(device, DEFAULT_STRENGTH, max_batch_size)
    }

    /// Builds a GPU-backed batcher on `device` with an explicit Poseidon `strength`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any error from `ClBatchHasher::new_with_strength`.
    #[cfg(any(feature = "cuda", feature = "opencl"))]
    pub fn with_strength(
        device: &Device,
        strength: Strength,
        max_batch_size: usize,
    ) -> Result<Self, Error> {
        let hasher = ClBatchHasher::<F, A>::new_with_strength(device, strength, max_batch_size)?;
        Ok(Self::OpenCl(hasher))
    }
}
impl<F, A> BatchHasher<F, A> for Batcher<F, A>
where
    F: NeptuneField,
    A: Arity<F>,
{
    /// Hashes every preimage with the backend this batcher wraps.
    ///
    /// The result (and any error) is returned exactly as the backend produces it.
    fn hash(&mut self, preimages: &[GenericArray<F, A>]) -> Result<Vec<F>, Error> {
        match self {
            Self::Cpu(cpu) => cpu.hash(preimages),
            #[cfg(any(feature = "cuda", feature = "opencl"))]
            Self::OpenCl(gpu) => gpu.hash(preimages),
        }
    }

    /// Reports the batch-size limit of the wrapped backend.
    fn max_batch_size(&self) -> usize {
        match self {
            Self::Cpu(cpu) => cpu.max_batch_size(),
            #[cfg(any(feature = "cuda", feature = "opencl"))]
            Self::OpenCl(gpu) => gpu.max_batch_size(),
        }
    }
}