Skip to content

Commit

Permalink
refactor: Use more serde_untagged
Browse files Browse the repository at this point in the history
  • Loading branch information
epage committed Aug 28, 2023
1 parent 94770a5 commit 17c4431
Showing 1 changed file with 55 additions and 143 deletions.
198 changes: 55 additions & 143 deletions src/cargo/util/toml/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,39 +400,22 @@ impl<'de> de::Deserialize<'de> for TomlOptLevel {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = TomlOptLevel;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("an optimization level")
}

fn visit_i64<E>(self, value: i64) -> Result<TomlOptLevel, E>
where
E: de::Error,
{
Ok(TomlOptLevel(value.to_string()))
}

fn visit_str<E>(self, value: &str) -> Result<TomlOptLevel, E>
where
E: de::Error,
{
use serde::de::Error as _;
UntaggedEnumVisitor::new()
.expecting("an optimization level")
.i64(|value| Ok(TomlOptLevel(value.to_string())))
.string(|value| {
if value == "s" || value == "z" {
Ok(TomlOptLevel(value.to_string()))
} else {
Err(E::custom(format!(
Err(serde_untagged::de::Error::custom(format!(
"must be `0`, `1`, `2`, `3`, `s` or `z`, \
but found the string: \"{}\"",
value
)))
}
}
}

d.deserialize_any(Visitor)
})
.deserialize(d)
}
}

Expand Down Expand Up @@ -477,58 +460,48 @@ impl<'de> de::Deserialize<'de> for TomlDebugInfo {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = TomlDebugInfo;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(
"a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\"",
)
}

fn visit_i64<E>(self, value: i64) -> Result<TomlDebugInfo, E>
where
E: de::Error,
{
use serde::de::Error as _;
let expecting = "a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\"";
UntaggedEnumVisitor::new()
.expecting(expecting)
.bool(|value| {
Ok(if value {
TomlDebugInfo::Full
} else {
TomlDebugInfo::None
})
})
.i64(|value| {
let debuginfo = match value {
0 => TomlDebugInfo::None,
1 => TomlDebugInfo::Limited,
2 => TomlDebugInfo::Full,
_ => return Err(de::Error::invalid_value(Unexpected::Signed(value), &self)),
_ => {
return Err(serde_untagged::de::Error::invalid_value(
Unexpected::Signed(value),
&expecting,
))
}
};
Ok(debuginfo)
}

fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(if v {
TomlDebugInfo::Full
} else {
TomlDebugInfo::None
})
}

fn visit_str<E>(self, value: &str) -> Result<TomlDebugInfo, E>
where
E: de::Error,
{
})
.string(|value| {
let debuginfo = match value {
"none" => TomlDebugInfo::None,
"limited" => TomlDebugInfo::Limited,
"full" => TomlDebugInfo::Full,
"line-directives-only" => TomlDebugInfo::LineDirectivesOnly,
"line-tables-only" => TomlDebugInfo::LineTablesOnly,
_ => return Err(de::Error::invalid_value(Unexpected::Str(value), &self)),
_ => {
return Err(serde_untagged::de::Error::invalid_value(
Unexpected::Str(value),
&expecting,
))
}
};
Ok(debuginfo)
}
}

d.deserialize_any(Visitor)
})
.deserialize(d)
}
}

Expand Down Expand Up @@ -927,32 +900,11 @@ impl<'de> de::Deserialize<'de> for StringOrVec {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = StringOrVec;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("string or list of strings")
}

fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(StringOrVec(vec![s.to_string()]))
}

fn visit_seq<V>(self, v: V) -> Result<Self::Value, V::Error>
where
V: de::SeqAccess<'de>,
{
let seq = de::value::SeqAccessDeserializer::new(v);
Vec::deserialize(seq).map(StringOrVec)
}
}

deserializer.deserialize_any(Visitor)
UntaggedEnumVisitor::new()
.expecting("string or list of strings")
.string(|value| Ok(StringOrVec(vec![value.to_owned()])))
.seq(|value| value.deserialize().map(StringOrVec))
.deserialize(deserializer)
}
}

Expand All @@ -975,8 +927,8 @@ impl<'de> Deserialize<'de> for StringOrBool {
D: de::Deserializer<'de>,
{
UntaggedEnumVisitor::new()
.string(|s| Ok(StringOrBool::String(s.to_owned())))
.bool(|b| Ok(StringOrBool::Bool(b)))
.string(|s| Ok(StringOrBool::String(s.to_owned())))
.deserialize(deserializer)
}
}
Expand All @@ -993,68 +945,28 @@ impl<'de> de::Deserialize<'de> for VecStringOrBool {
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = VecStringOrBool;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a boolean or vector of strings")
}

fn visit_seq<V>(self, v: V) -> Result<Self::Value, V::Error>
where
V: de::SeqAccess<'de>,
{
let seq = de::value::SeqAccessDeserializer::new(v);
Vec::deserialize(seq).map(VecStringOrBool::VecString)
}

fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(VecStringOrBool::Bool(b))
}
}

deserializer.deserialize_any(Visitor)
UntaggedEnumVisitor::new()
.expecting("a boolean or vector of strings")
.bool(|value| Ok(VecStringOrBool::Bool(value)))
.seq(|value| value.deserialize().map(VecStringOrBool::VecString))
.deserialize(deserializer)
}
}

fn version_trim_whitespace<'de, D>(deserializer: D) -> Result<MaybeWorkspaceSemverVersion, D::Error>
where
D: de::Deserializer<'de>,
{
struct Visitor;

impl<'de> de::Visitor<'de> for Visitor {
type Value = MaybeWorkspaceSemverVersion;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("SemVer version")
}

fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match string.trim().parse().map_err(de::Error::custom) {
UntaggedEnumVisitor::new()
.expecting("SemVer version")
.string(
|value| match value.trim().parse().map_err(de::Error::custom) {
Ok(parsed) => Ok(MaybeWorkspace::Defined(parsed)),
Err(e) => Err(e),
}
}

fn visit_map<V>(self, map: V) -> Result<Self::Value, V::Error>
where
V: de::MapAccess<'de>,
{
let mvd = de::value::MapAccessDeserializer::new(map);
TomlWorkspaceField::deserialize(mvd).map(MaybeWorkspace::Workspace)
}
}

deserializer.deserialize_any(Visitor)
},
)
.map(|value| value.deserialize().map(MaybeWorkspace::Workspace))
.deserialize(deserializer)
}

/// This Trait exists to make [`MaybeWorkspace::Workspace`] generic. It makes deserialization of
Expand Down

0 comments on commit 17c4431

Please sign in to comment.