Skip to content

Commit

Permalink
update blob.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas-k-cameron committed Apr 27, 2023
1 parent a85cef6 commit 61275ae
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
10 changes: 10 additions & 0 deletions rust-runtime/aws-smithy-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
criterion = "0.4"
rand = "0.8.4"
ciborium = "0.2.0"

[package.metadata.docs.rs]
all-features = true
Expand All @@ -32,3 +33,12 @@ rustdoc-args = ["--cfg", "docsrs"]
[[bench]]
name = "base64"
harness = false


[target."cfg(aws_sdk_unstable)".dependencies.serde]
version = "1"
features = ["derive"]

[features]
serde-serialize = []
serde-deserialize = []
129 changes: 129 additions & 0 deletions rust-runtime/aws-smithy-types/src/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
* SPDX-License-Identifier: Apache-2.0
*/

#[cfg(all(
aws_sdk_unstable,
any(feature = "serde-deserialize", feature = "serde-serialize")
))]
use crate::base64;
#[cfg(all(aws_sdk_unstable, feature = "serde-serialize"))]
use serde::Serialize;
#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
use serde::{de::Visitor, Deserialize};

/// Binary Blob Type
///
/// Blobs represent protocol-agnostic binary content.
Expand Down Expand Up @@ -30,3 +40,122 @@ impl AsRef<[u8]> for Blob {
&self.inner
}
}

#[cfg(all(aws_sdk_unstable, feature = "serde-serialize"))]
impl Serialize for Blob {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
if serializer.is_human_readable() {
serializer.serialize_str(&crate::base64::encode(&self.inner))
} else {
serializer.serialize_bytes(&self.inner)
}
}
}

#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
struct HumanReadableBlobVisitor;

#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
impl<'de> Visitor<'de> for HumanReadableBlobVisitor {
type Value = Blob;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("expected base64 encoded string")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match base64::decode(v) {
Ok(inner) => Ok(Blob { inner }),
Err(e) => Err(E::custom(e)),
}
}
}

#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
struct NotHumanReadableBlobVisitor;

#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
impl<'de> Visitor<'de> for NotHumanReadableBlobVisitor {
type Value = Blob;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("expected base64 encoded string")
}

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Blob { inner: v })
}
}

#[cfg(all(aws_sdk_unstable, feature = "serde-deserialize"))]
impl<'de> Deserialize<'de> for Blob {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(HumanReadableBlobVisitor)
} else {
deserializer.deserialize_byte_buf(NotHumanReadableBlobVisitor)
}
}
}

#[cfg(test)]
#[cfg(all(
aws_sdk_unstable,
feature = "serde-serialize",
feature = "serde-deserialize"
))]
mod test {
use crate::Blob;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Deserialize, Serialize, Debug, PartialEq)]
struct ForTest {
blob: Blob,
}

#[test]
fn human_readable_blob() {
let aws_in_base64 = r#"{"blob":"QVdT"}"#;
let for_test = ForTest {
blob: Blob {
inner: vec![b'A', b'W', b'S'],
},
};
assert_eq!(for_test, serde_json::from_str(aws_in_base64).unwrap());
assert_eq!(serde_json::to_string(&for_test).unwrap(), aws_in_base64);
}

#[test]
fn not_human_readable_blob() {
use std::ffi::CString;

let for_test = ForTest {
blob: Blob {
inner: vec![b'A', b'W', b'S'],
},
};
let mut buf = vec![];
let res = ciborium::ser::into_writer(&for_test, &mut buf);
assert!(res.is_ok());

// checks whether the bytes are deserialiezd properly
let n: HashMap<String, CString> =
ciborium::de::from_reader(std::io::Cursor::new(buf.clone())).unwrap();
assert!(n.get("blob").is_some());
assert!(n.get("blob") == CString::new([65, 87, 83]).ok().as_ref());

let de: ForTest = ciborium::de::from_reader(std::io::Cursor::new(buf)).unwrap();
assert_eq!(for_test, de);
}
}

0 comments on commit 61275ae

Please sign in to comment.