Skip to content

Commit

Permalink
Restrict special ok/error handling to variant names
Browse files Browse the repository at this point in the history
This prevents ok/error atoms from getting mangled during transcoding.
Also adds a sanity check so only Result variants get special treatment
instead of any variant named "Ok" or "Err"
  • Loading branch information
benhaney committed Jul 14, 2024
1 parent ef20458 commit c183948
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 26 deletions.
23 changes: 5 additions & 18 deletions rustler/src/serde/atoms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
use crate::serde::Error;
use crate::{types::atom::Atom, Encoder, Env, Term};

pub static OK: &str = "Ok";
pub static ERROR: &str = "Err";

atoms! {
nil,
ok,
Expand All @@ -23,28 +20,18 @@ atoms! {
* Attempts to create an atom term from the provided string (if the atom already exists in the atom table). If not, returns a string term.
*/
pub fn str_to_term<'a>(env: &Env<'a>, string: &str) -> Result<Term<'a>, Error> {
if string == "Ok" {
Ok(ok().encode(*env))
} else if string == "Err" {
Ok(error().encode(*env))
} else {
match Atom::try_from_bytes(*env, string.as_bytes()) {
Ok(Some(term)) => Ok(term.encode(*env)),
Ok(None) => Err(Error::InvalidStringable),
_ => Err(Error::InvalidStringable),
}
match Atom::try_from_bytes(*env, string.as_bytes()) {
Ok(Some(term)) => Ok(term.encode(*env)),
Ok(None) => Ok(string.encode(*env)),
_ => Err(Error::InvalidStringable),
}
}

/**
* Attempts to create a `String` from the term.
*/
pub fn term_to_string(term: &Term) -> Result<String, Error> {
if ok().eq(term) {
Ok(OK.to_string())
} else if error().eq(term) {
Ok(ERROR.to_string())
} else if term.is_atom() {
if term.is_atom() {
term.atom_to_string().or(Err(Error::InvalidAtom))
} else {
Err(Error::InvalidStringable)
Expand Down
9 changes: 7 additions & 2 deletions rustler/src/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,13 @@ impl<'de, 'a: 'de> de::Deserializer<'de> for VariantNameDeserializer<'a> {
{
match self.variant.get_type() {
TermType::Atom => {
let string =
atoms::term_to_string(&self.variant).or(Err(Error::InvalidVariantName))?;
let string = atoms::term_to_string(&self.variant)
.map(|s| match s.as_str() {
"ok" => "Ok".to_string(),
"error" => "Err".to_string(),
_ => s,
})
.or(Err(Error::InvalidVariantName))?;
visitor.visit_string(string)
}
TermType::Binary => visitor.visit_string(util::term_to_str(&self.variant)?),
Expand Down
8 changes: 4 additions & 4 deletions rustler/src/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,17 @@ impl<'a> ser::Serializer for Serializer<'a> {
/// `enum Result { Ok(u8), Err(_) }` into `{:ok, u8}` or `{:err, _}`.
fn serialize_newtype_variant<T>(
self,
_name: &'static str,
name: &'static str,
_variant_index: u32,
variant: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error>
where
T: ?Sized + ser::Serialize,
{
match variant {
"Ok" => self.serialize_newtype_struct("ok", value),
"Err" => self.serialize_newtype_struct("error", value),
match (name, variant) {
("Result", "Ok") => self.serialize_newtype_struct("ok", value),
("Result", "Err") => self.serialize_newtype_struct("error", value),
_ => self.serialize_newtype_struct(variant, value),
}
}
Expand Down
4 changes: 2 additions & 2 deletions rustler_tests/test/serde_rustler_tests_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,14 @@ defmodule SerdeRustlerTests.NifTest do

test "newtype variant (Result::Ok(T), or {:ok, T})", ctx do
test_case = {:ok, 255}
transcoded = ["Ok", 255]
transcoded = ["ok", 255]
run_tests("newtype variant (ok tuple)", test_case, Helpers.skip(ctx, :transcode))
Helpers.run_transcode("newtype variant (ok tuple)", test_case, transcoded)
end

test "newtype variant (Result::Err(T), or {:error, T}", ctx do
test_case = {:error, "error reason"}
transcoded = ["Err", "error reason"]
transcoded = ["error", "error reason"]
run_tests("newtype variant (error tuple)", test_case, Helpers.skip(ctx, :transcode))
Helpers.run_transcode("newtype variant (error tuple)", test_case, transcoded)
end
Expand Down

0 comments on commit c183948

Please sign in to comment.