From 4da2a9ca2bef37a235bd1c89dc2d2c34efc1c01e Mon Sep 17 00:00:00 2001 From: Jonathan Strong Date: Thu, 4 Mar 2021 13:22:40 -0500 Subject: [PATCH] adds `read_npz` and `write_npz`, convenience wrappers for `NpzWriter` and `NpzReader` similar to `read_npy` and `write_npy` --- resources/array.npz | Bin 0 -> 221 bytes src/lib.rs | 2 +- src/npy/mod.rs | 72 ++++++++++++++++++++++++++++++++++++++++++++ tests/examples.rs | 27 +++++++++++++++++ 4 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 resources/array.npz diff --git a/resources/array.npz b/resources/array.npz new file mode 100644 index 0000000000000000000000000000000000000000..151b5c009a6f0a3ca6a526a8d3309a20b339bf1b GIT binary patch literal 221 zcmWIWW@Zs#U|`??Vnv34Exs+GK-LT(=423INGvLfH_*!~sAOai01E=e89)j^VD=mF zSJDC}PXrtbSUYc0)SQ?lc}o}MP0n1jZt;S7)8a#drpyqZA3tRhmr%L)8BIUdF7b2~ zu_a7DrQEEzj?J>-+RBx}tjDI1z>=h(()) +/// ``` +pub fn write_npz(path: P, array: &ArrayBase) -> Result<(), crate::WriteNpzError> +where + P: AsRef, + S::Elem: WritableElement, + S: Data, + D: Dimension +{ + let file = std::fs::File::create(path) + .map_err(|e| crate::WriteNpzError::Npy(WriteNpyError::Io(e)))?; + let mut wtr = crate::NpzWriter::new_compressed(file); + wtr.add_array("arr_0.npy", array)?; + Ok(()) +} + +/// Read an array from a `.npz` file located at the specified path and name. +/// +/// This is a convience function for opening a file and using `NpzReader` to +/// extract one array from it. +/// +/// The name of a single array written to an `.npz` file using `write_npz` +/// will be "arr_0.npy", following numpy's conventions for labeling unnamed +/// arrays in `savez_compressed`. +/// +/// # Example +/// +/// ``` +/// use ndarray::Array2; +/// use ndarray_npy::read_npz; +/// # use ndarray_npy::ReadNpzError; +/// let arr: Array2 = read_npz("resources/array.npz", "arr_0.npy")?; +/// # println!("arr = {}", arr); +/// # Ok::<_, ReadNpzError>(()) +/// ``` +pub fn read_npz(path: P, name: N) -> Result, crate::ReadNpzError> +where + P: AsRef, + N: Into, + S::Elem: ReadableElement, + S: DataOwned, + D: Dimension, +{ + let file = std::fs::File::open(path) + .map_err(|e| crate::ReadNpzError::Npy(ReadNpyError::Io(e)))?; + let mut rdr = crate::NpzReader::new(file)?; + let name: String = name.into(); + let arr = rdr.by_name(&name)?; + Ok(arr) +} + /// Writes an `.npy` file (sparse if possible) with bitwise-zero-filled data. /// /// The `.npy` file represents an array with element type `A` and shape diff --git a/tests/examples.rs b/tests/examples.rs index 5f62784..d672c59 100644 --- a/tests/examples.rs +++ b/tests/examples.rs @@ -302,3 +302,30 @@ fn zeroed() { assert_eq!(arr, Array3::::zeros(SHAPE)); assert!(arr.is_standard_layout()); } + +#[test] +fn convenience_functions_round_trip_f64_standard() { + let mut arr = Array3::::zeros((2, 3, 4)); + for (i, elem) in arr.iter_mut().enumerate() { + *elem = (i as f64).sin() * std::f64::consts::PI; + } + + let tmp = tempfile::tempdir().unwrap(); + + // npy round trip + let npy_path = tmp.path().join("f64-example.npy"); + ndarray_npy::write_npy(&npy_path, &arr).unwrap(); + assert!(npy_path.exists()); + let rt_arr: Array3 = ndarray_npy::read_npy(&npy_path).unwrap(); + assert_eq!(arr, rt_arr); + + // npz round trip + let npz_path = tmp.path().join("f64-example.npz"); + ndarray_npy::write_npz(&npz_path, &arr).unwrap(); + assert!(npz_path.exists()); + let rtz_arr: Array3 = ndarray_npy::read_npz(&npz_path, "arr_0.npy").unwrap(); + assert_eq!(arr, rtz_arr); + tmp.close().unwrap(); +} + +