Skip to content

Commit

Permalink
Implement FromStr for DeriveColumn
Browse files Browse the repository at this point in the history
  • Loading branch information
tqwewe authored and tyt2y3 committed Aug 21, 2021
1 parent 35c9cbe commit 93953c3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 6 deletions.
45 changes: 43 additions & 2 deletions sea-orm-macros/src/derives/column.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use heck::SnakeCase;
use heck::{MixedCase, SnakeCase};
use proc_macro2::{Ident, TokenStream};
use quote::{quote, quote_spanned};
use quote::{format_ident, quote, quote_spanned};
use syn::{Data, DataEnum, Fields, Variant};

pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
Expand Down Expand Up @@ -41,6 +41,44 @@ pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result<TokenStrea
))
}

pub fn impl_col_from_str(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let parse_error_iden = format_ident!("Parse{}Err", ident);

let data_enum = match data {
Data::Enum(data_enum) => data_enum,
_ => {
return Ok(quote_spanned! {
ident.span() => compile_error!("you can only derive DeriveColumn on enums");
})
}
};

let columns = data_enum.variants.iter().map(|column| {
let column_iden = column.ident.clone();
let column_str_snake = column_iden.to_string().to_snake_case();
let column_str_mixed = column_iden.to_string().to_mixed_case();
quote!(
#column_str_snake | #column_str_mixed => Ok(#ident::#column_iden)
)
});

Ok(quote!(
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct #parse_error_iden;

impl std::str::FromStr for #ident {
type Err = #parse_error_iden;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
#(#columns),*,
_ => Err(#parse_error_iden),
}
}
}
))
}

pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let impl_iden = expand_derive_custom_column(ident, data)?;

Expand All @@ -57,10 +95,13 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result<TokenStre

pub fn expand_derive_custom_column(ident: &Ident, data: &Data) -> syn::Result<TokenStream> {
let impl_default_as_str = impl_default_as_str(ident, data)?;
let impl_col_from_str = impl_col_from_str(ident, data)?;

Ok(quote!(
#impl_default_as_str

#impl_col_from_str

impl sea_orm::Iden for #ident {
fn unquoted(&self, s: &mut dyn std::fmt::Write) {
write!(s, "{}", self.as_str()).unwrap();
Expand Down
25 changes: 25 additions & 0 deletions src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ mod tests {
tests_cfg::*, ColumnTrait, Condition, DbBackend, EntityTrait, QueryFilter, QueryTrait,
};
use sea_query::Query;
use std::str::FromStr;

#[test]
fn test_in_subquery() {
Expand All @@ -348,4 +349,28 @@ mod tests {
.join(" ")
);
}

#[test]
fn test_col_from_str() {
match fruit::Column::from_str("id") {
Ok(col) => assert_eq!(col, fruit::Column::Id),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("name") {
Ok(col) => assert_eq!(col, fruit::Column::Name),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("cake_id") {
Ok(col) => assert_eq!(col, fruit::Column::CakeId),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("cakeId") {
Ok(col) => assert_eq!(col, fruit::Column::CakeId),
Err(_) => panic!("fruit from_str fails"),
}
match fruit::Column::from_str("does_not_exist") {
Ok(_) => panic!("fruit from_str found match when it shouldn't have"),
Err(err) => assert_eq!(err, fruit::ParseColumnErr),
}
}
}
4 changes: 2 additions & 2 deletions src/query/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ mod tests {
.left_join(fruit::Entity)
.select_also(fruit::Entity)
.filter(cake::Column::Id.eq(1))
.filter(fruit::Column::Id.eq(2))
.filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::MySql)
.to_string(),
[
Expand All @@ -186,7 +186,7 @@ mod tests {
.left_join(fruit::Entity)
.select_with(fruit::Entity)
.filter(cake::Column::Id.eq(1))
.filter(fruit::Column::Id.eq(2))
.filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::MySql)
.to_string(),
[
Expand Down
2 changes: 1 addition & 1 deletion src/query/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ mod tests {
assert_eq!(
Update::many(fruit::Entity)
.col_expr(fruit::Column::CakeId, Expr::value(Value::Int(None)))
.filter(fruit::Column::Id.eq(2))
.filter(ColumnTrait::eq(&fruit::Column::Id, 2))
.build(DbBackend::Postgres)
.to_string(),
r#"UPDATE "fruit" SET "cake_id" = NULL WHERE "fruit"."id" = 2"#,
Expand Down
2 changes: 1 addition & 1 deletion src/tests_cfg/fruit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct Model {
pub cake_id: Option<i32>,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
#[derive(Copy, Clone, PartialEq, Debug, EnumIter, DeriveColumn)]
pub enum Column {
Id,
Name,
Expand Down

0 comments on commit 93953c3

Please sign in to comment.