Skip to content

Commit

Permalink
Add __hash__(), __str__() and __eq()__ where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Mar 28, 2021
1 parent a9808ca commit aac9775
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
1 change: 1 addition & 0 deletions umbral-pre-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.13", features = ["extension-module"] }
umbral-pre = { path = "../umbral-pre" }
generic-array = "0.14"
hex = "0.4"
79 changes: 77 additions & 2 deletions umbral-pre-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3::create_exception;
use pyo3::exceptions::{PyException, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::PyClass;
use pyo3::types::PyBytes;
use pyo3::types::{PyBytes, PyUnicode};
use pyo3::wrap_pyfunction;
use pyo3::PyObjectProtocol;

Expand Down Expand Up @@ -42,6 +42,27 @@ fn from_bytes<T: HasSerializableBackend<U>, U: SerializableToArray>(bytes: &[u8]
})
}

fn hash<T: HasSerializableBackend<U>, U: SerializableToArray>(obj: &T) -> PyResult<isize> {
let serialized = obj.as_backend().to_array();

// call `hash((class_name, bytes(obj)))`
Python::with_gil(|py| {
let builtins = PyModule::import(py, "builtins")?;
let arg1 = PyUnicode::new(py, T::name());
let arg2: PyObject = PyBytes::new(py, serialized.as_slice()).into();
builtins.getattr("hash")?.call1(((arg1, arg2),))?.extract()
})
}

// For some reason this lint is not recognized in Rust 1.46 (the one in CI)
// remove when CI is updated to a newer Rust version.
#[allow(clippy::unknown_clippy_lints)]
#[allow(clippy::unnecessary_wraps)] // Don't want to wrap it in Ok() on every call
fn hexstr<T: HasSerializableBackend<U>, U: SerializableToArray>(obj: &T) -> PyResult<String> {
let hex_str = hex::encode(obj.as_backend().to_array().as_slice());
Ok(format!("{}:{}", T::name(), &hex_str[0..16]))
}

fn richcmp<T: HasSerializableBackend<U> + PyClass + PartialEq, U>(
obj: &T,
other: PyRef<T>,
Expand Down Expand Up @@ -103,6 +124,10 @@ impl PyObjectProtocol for SecretKey {
fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __str__(&self) -> PyResult<String> {
Ok(format!("{}:...", Self::name()))
}
}

#[pyclass(module = "umbral")]
Expand Down Expand Up @@ -163,6 +188,10 @@ impl PyObjectProtocol for SecretKeyFactory {
fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __str__(&self) -> PyResult<String> {
Ok(format!("{}:...", Self::name()))
}
}

#[pyclass(module = "umbral")]
Expand Down Expand Up @@ -209,9 +238,18 @@ impl PyObjectProtocol for PublicKey {
fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __hash__(&self) -> PyResult<isize> {
hash(self)
}

fn __str__(&self) -> PyResult<String> {
hexstr(self)
}
}

#[pyclass(module = "umbral")]
#[derive(PartialEq)]
pub struct Capsule {
backend: umbral_pre::Capsule,
}
Expand Down Expand Up @@ -240,9 +278,21 @@ impl Capsule {

#[pyproto]
impl PyObjectProtocol for Capsule {
fn __richcmp__(&self, other: PyRef<Capsule>, op: CompareOp) -> PyResult<bool> {
richcmp(self, other, op)
}

fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __hash__(&self) -> PyResult<isize> {
hash(self)
}

fn __str__(&self) -> PyResult<String> {
hexstr(self)
}
}

#[pyfunction]
Expand Down Expand Up @@ -289,6 +339,7 @@ pub fn decrypt_original(
}

#[pyclass(module = "umbral")]
#[derive(PartialEq)]
pub struct KeyFrag {
backend: umbral_pre::KeyFrag,
}
Expand Down Expand Up @@ -330,9 +381,21 @@ impl KeyFrag {

#[pyproto]
impl PyObjectProtocol for KeyFrag {
fn __richcmp__(&self, other: PyRef<KeyFrag>, op: CompareOp) -> PyResult<bool> {
richcmp(self, other, op)
}

fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __hash__(&self) -> PyResult<isize> {
hash(self)
}

fn __str__(&self) -> PyResult<String> {
hexstr(self)
}
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -364,7 +427,7 @@ pub fn generate_kfrags(
}

#[pyclass(module = "umbral")]
#[derive(Clone)]
#[derive(Clone, PartialEq)]
pub struct CapsuleFrag {
backend: umbral_pre::CapsuleFrag,
}
Expand Down Expand Up @@ -410,9 +473,21 @@ impl CapsuleFrag {

#[pyproto]
impl PyObjectProtocol for CapsuleFrag {
fn __richcmp__(&self, other: PyRef<CapsuleFrag>, op: CompareOp) -> PyResult<bool> {
richcmp(self, other, op)
}

fn __bytes__(&self) -> PyResult<PyObject> {
to_bytes(self)
}

fn __hash__(&self) -> PyResult<isize> {
hash(self)
}

fn __str__(&self) -> PyResult<String> {
hexstr(self)
}
}

#[pyfunction]
Expand Down

0 comments on commit aac9775

Please sign in to comment.