Skip to content

Commit

Permalink
refactor: Use more serde_untagged
Browse files Browse the repository at this point in the history
I felt this does a good job of cleaning up the code and by using it
more, new uses are more likely to use it.

Due to an error reporting limitation in `serde_untagged`, I'm not using
this for some `MaybeWorkspace` types because it makes the errors worse.

I also held off on some config visitors because they seemed more
complicated and I didn't want to risk that code.
  • Loading branch information
epage committed Aug 28, 2023
1 parent 77a9b2d commit 4da5fa5
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 208 deletions.
23 changes: 5 additions & 18 deletions src/cargo/util/interning.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Serialize, Serializer};
use serde_untagged::UntaggedEnumVisitor;
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::HashSet;
Expand Down Expand Up @@ -150,28 +151,14 @@ impl Serialize for InternedString {
}
}

struct InternedStringVisitor;

impl<'de> serde::Deserialize<'de> for InternedString {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(InternedStringVisitor)
}
}

impl<'de> serde::de::Visitor<'de> for InternedStringVisitor {
type Value = InternedString;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("an String like thing")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(InternedString::new(v))
UntaggedEnumVisitor::new()
.expecting("an String like thing")
.string(|value| Ok(InternedString::new(value)))
.deserialize(deserializer)
}
}
24 changes: 5 additions & 19 deletions src/cargo/util/semver_ext.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use semver::{Comparator, Op, Version, VersionReq};
use serde_untagged::UntaggedEnumVisitor;
use std::fmt::{self, Display};

#[derive(PartialEq, Eq, Hash, Clone, Debug)]
Expand Down Expand Up @@ -198,25 +199,10 @@ impl<'de> serde::Deserialize<'de> for PartialVersion {
where
D: serde::Deserializer<'de>,
{
struct VersionVisitor;

impl<'de> serde::de::Visitor<'de> for VersionVisitor {
type Value = PartialVersion;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("SemVer version")
}

fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
string.parse().map_err(serde::de::Error::custom)
}
}

let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
UntaggedEnumVisitor::new()
.expecting("SemVer version")
.string(|value| value.parse().map_err(serde::de::Error::custom))
.deserialize(deserializer)
}
}

Expand Down
233 changes: 62 additions & 171 deletions src/cargo/util/toml/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::ffi::OsStr;
use std::fmt::{self, Display, Write};
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::str::{self, FromStr};
Expand Down Expand Up @@ -213,34 +212,14 @@ impl<'de, P: Deserialize<'de> + Clone> de::Deserialize<'de> for TomlDependency<P
where
D: de::Deserializer<'de>,
{
struct TomlDependencyVisitor<P>(PhantomData<P>);

impl<'de, P: Deserialize<'de> + Clone> de::Visitor<'de> for TomlDependencyVisitor<P> {
type Value = TomlDependency<P>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(
"a version string like \"0.9.8\" or a \
UntaggedEnumVisitor::new()
.expecting(
"a version string like \"0.9.8\" or a \
detailed dependency like { version = \"0.9.8\" }",
)
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(TomlDependency::Simple(s.to_owned()))
}

fn visit_map<V>(self, map: V) -> Result<Self::Value, V::Error>
where
V: de::MapAccess<'de>,
{
let mvd = de::value::MapAccessDeserializer::new(map);
DetailedTomlDependency::deserialize(mvd).map(TomlDependency::Detailed)
}
}
deserializer.deserialize_any(TomlDependencyVisitor(PhantomData))
)
.string(|value| Ok(TomlDependency::Simple(value.to_owned())))
.map(|value| value.deserialize().map(TomlDependency::Detailed))
.deserialize(deserializer)
}
}

Expand Down Expand Up @@ -400,39 +379,22 @@ impl<'de> de::Deserialize<'de> for TomlOptLevel {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = TomlOptLevel;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("an optimization level")
}

fn visit_i64<E>(self, value: i64) -> Result<TomlOptLevel, E>
where
E: de::Error,
{
Ok(TomlOptLevel(value.to_string()))
}

fn visit_str<E>(self, value: &str) -> Result<TomlOptLevel, E>
where
E: de::Error,
{
use serde::de::Error as _;
UntaggedEnumVisitor::new()
.expecting("an optimization level")
.i64(|value| Ok(TomlOptLevel(value.to_string())))
.string(|value| {
if value == "s" || value == "z" {
Ok(TomlOptLevel(value.to_string()))
} else {
Err(E::custom(format!(
Err(serde_untagged::de::Error::custom(format!(
"must be `0`, `1`, `2`, `3`, `s` or `z`, \
but found the string: \"{}\"",
value
)))
}
}
}

d.deserialize_any(Visitor)
})
.deserialize(d)
}
}

Expand Down Expand Up @@ -477,58 +439,48 @@ impl<'de> de::Deserialize<'de> for TomlDebugInfo {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = TomlDebugInfo;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(
"a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\"",
)
}

fn visit_i64<E>(self, value: i64) -> Result<TomlDebugInfo, E>
where
E: de::Error,
{
use serde::de::Error as _;
let expecting = "a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\"";
UntaggedEnumVisitor::new()
.expecting(expecting)
.bool(|value| {
Ok(if value {
TomlDebugInfo::Full
} else {
TomlDebugInfo::None
})
})
.i64(|value| {
let debuginfo = match value {
0 => TomlDebugInfo::None,
1 => TomlDebugInfo::Limited,
2 => TomlDebugInfo::Full,
_ => return Err(de::Error::invalid_value(Unexpected::Signed(value), &self)),
_ => {
return Err(serde_untagged::de::Error::invalid_value(
Unexpected::Signed(value),
&expecting,
))
}
};
Ok(debuginfo)
}

fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(if v {
TomlDebugInfo::Full
} else {
TomlDebugInfo::None
})
}

fn visit_str<E>(self, value: &str) -> Result<TomlDebugInfo, E>
where
E: de::Error,
{
})
.string(|value| {
let debuginfo = match value {
"none" => TomlDebugInfo::None,
"limited" => TomlDebugInfo::Limited,
"full" => TomlDebugInfo::Full,
"line-directives-only" => TomlDebugInfo::LineDirectivesOnly,
"line-tables-only" => TomlDebugInfo::LineTablesOnly,
_ => return Err(de::Error::invalid_value(Unexpected::Str(value), &self)),
_ => {
return Err(serde_untagged::de::Error::invalid_value(
Unexpected::Str(value),
&expecting,
))
}
};
Ok(debuginfo)
}
}

d.deserialize_any(Visitor)
})
.deserialize(d)
}
}

Expand Down Expand Up @@ -927,32 +879,11 @@ impl<'de> de::Deserialize<'de> for StringOrVec {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = StringOrVec;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("string or list of strings")
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(StringOrVec(vec![s.to_string()]))
}

fn visit_seq<V>(self, v: V) -> Result<Self::Value, V::Error>
where
V: de::SeqAccess<'de>,
{
let seq = de::value::SeqAccessDeserializer::new(v);
Vec::deserialize(seq).map(StringOrVec)
}
}

deserializer.deserialize_any(Visitor)
UntaggedEnumVisitor::new()
.expecting("string or list of strings")
.string(|value| Ok(StringOrVec(vec![value.to_owned()])))
.seq(|value| value.deserialize().map(StringOrVec))
.deserialize(deserializer)
}
}

Expand All @@ -975,8 +906,8 @@ impl<'de> Deserialize<'de> for StringOrBool {
D: de::Deserializer<'de>,
{
UntaggedEnumVisitor::new()
.string(|s| Ok(StringOrBool::String(s.to_owned())))
.bool(|b| Ok(StringOrBool::Bool(b)))
.string(|s| Ok(StringOrBool::String(s.to_owned())))
.deserialize(deserializer)
}
}
Expand All @@ -993,68 +924,28 @@ impl<'de> de::Deserialize<'de> for VecStringOrBool {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = VecStringOrBool;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a boolean or vector of strings")
}

fn visit_seq<V>(self, v: V) -> Result<Self::Value, V::Error>
where
V: de::SeqAccess<'de>,
{
let seq = de::value::SeqAccessDeserializer::new(v);
Vec::deserialize(seq).map(VecStringOrBool::VecString)
}

fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(VecStringOrBool::Bool(b))
}
}

deserializer.deserialize_any(Visitor)
UntaggedEnumVisitor::new()
.expecting("a boolean or vector of strings")
.bool(|value| Ok(VecStringOrBool::Bool(value)))
.seq(|value| value.deserialize().map(VecStringOrBool::VecString))
.deserialize(deserializer)
}
}

fn version_trim_whitespace<'de, D>(deserializer: D) -> Result<MaybeWorkspaceSemverVersion, D::Error>
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = MaybeWorkspaceSemverVersion;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("SemVer version")
}

fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match string.trim().parse().map_err(de::Error::custom) {
UntaggedEnumVisitor::new()
.expecting("SemVer version")
.string(
|value| match value.trim().parse().map_err(de::Error::custom) {
Ok(parsed) => Ok(MaybeWorkspace::Defined(parsed)),
Err(e) => Err(e),
}
}

fn visit_map<V>(self, map: V) -> Result<Self::Value, V::Error>
where
V: de::MapAccess<'de>,
{
let mvd = de::value::MapAccessDeserializer::new(map);
TomlWorkspaceField::deserialize(mvd).map(MaybeWorkspace::Workspace)
}
}

deserializer.deserialize_any(Visitor)
},
)
.map(|value| value.deserialize().map(MaybeWorkspace::Workspace))
.deserialize(deserializer)
}

/// This Trait exists to make [`MaybeWorkspace::Workspace`] generic. It makes deserialization of
Expand Down

0 comments on commit 4da5fa5

Please sign in to comment.