Skip to content

Commit

Permalink
Allow raw identifiers for SqlIdentifier (column-name)
Browse files Browse the repository at this point in the history
This allows using the `r#identifier` syntax for SqlIdentifier
(column-name), which is used in derive macros.
Previously the derive macros would panic when encountering such an
identifier:
    `"r#identifier"` is not a valid identifier
  • Loading branch information
z33ky committed May 22, 2024
1 parent 3c7b7c4 commit 113b6b5
Show file tree
Hide file tree
Showing 15 changed files with 761 additions and 137 deletions.
4 changes: 3 additions & 1 deletion .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ extend-ignore-re = [
"cannot find value `titel` in module `posts`",
"cannot find type `titel` in module `posts`",
"[0-9]+[[:space]]+|[[:space:]]+titel: String",
"big_sur"
"big_sur",
# That's Spanish for "type" (used in a unit-test)
"tipe",
]

[type.md]
Expand Down
32 changes: 28 additions & 4 deletions diesel_compile_tests/tests/fail/derive/bad_column_name.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,38 @@ error: expected string literal
30 | #[diesel(column_name = true)]
| ^^^^

error: Expected valid identifier, found `type`. Diesel automatically renames invalid identifiers, perhaps you meant to write `type_`?
error[E0412]: cannot find type `r#type` in module `users`
--> tests/fail/derive/bad_column_name.rs:38:28
|
9 | tpe -> Text,
| --- similarly named struct `tpe` defined here
...
38 | #[diesel(column_name = "type")]
| ^^^^^^
| ^^^^^^ help: a struct with a similar name exists: `tpe`

error: Expected valid identifier, found `type`. Diesel automatically renames invalid identifiers, perhaps you meant to write `type_`?
error[E0425]: cannot find value `r#type` in module `users`
--> tests/fail/derive/bad_column_name.rs:38:28
|
9 | tpe -> Text,
| --- similarly named unit struct `tpe` defined here
...
38 | #[diesel(column_name = "type")]
| ^^^^^^ help: a unit struct with a similar name exists: `tpe`

error[E0412]: cannot find type `r#type` in module `users`
--> tests/fail/derive/bad_column_name.rs:46:28
|
9 | tpe -> Text,
| --- similarly named struct `tpe` defined here
...
46 | #[diesel(column_name = "type")]
| ^^^^^^ help: a struct with a similar name exists: `tpe`

error[E0425]: cannot find value `r#type` in module `users`
--> tests/fail/derive/bad_column_name.rs:46:28
|
9 | tpe -> Text,
| --- similarly named unit struct `tpe` defined here
...
46 | #[diesel(column_name = "type")]
| ^^^^^^
| ^^^^^^ help: a unit struct with a similar name exists: `tpe`
12 changes: 4 additions & 8 deletions diesel_derives/src/as_changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ fn field_changeset_ty(
lifetime: Option<TokenStream>,
treat_none_as_null: bool,
) -> Result<TokenStream> {
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
if !treat_none_as_null && is_option_ty(&field.ty) {
let field_ty = inner_of_option_ty(&field.ty);
Ok(
Expand All @@ -177,8 +176,7 @@ fn field_changeset_expr(
treat_none_as_null: bool,
) -> Result<TokenStream> {
let field_name = &field.name;
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
if !treat_none_as_null && is_option_ty(&field.ty) {
if lifetime.is_some() {
Ok(quote!(self.#field_name.as_ref().map(|x| #table_name::#column_name.eq(x))))
Expand All @@ -196,8 +194,7 @@ fn field_changeset_ty_serialize_as(
ty: &Type,
treat_none_as_null: bool,
) -> Result<TokenStream> {
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
if !treat_none_as_null && is_option_ty(&field.ty) {
let inner_ty = inner_of_option_ty(ty);
Ok(quote!(std::option::Option<diesel::dsl::Eq<#table_name::#column_name, #inner_ty>>))
Expand All @@ -213,8 +210,7 @@ fn field_changeset_expr_serialize_as(
treat_none_as_null: bool,
) -> Result<TokenStream> {
let field_name = &field.name;
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
let column: Expr = parse_quote!(#table_name::#column_name);
if !treat_none_as_null && is_option_ty(&field.ty) {
Ok(quote!(self.#field_name.map(|x| #column.eq(::std::convert::Into::<#ty>::into(x)))))
Expand Down
28 changes: 20 additions & 8 deletions diesel_derives/src/attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,32 +56,42 @@ impl SqlIdentifier {
self.span
}

pub fn valid_ident(&self) -> Result<()> {
if syn::parse_str::<Ident>(&self.field_name).is_err() {
Err(syn::Error::new(
pub fn to_ident(&self) -> Result<Ident> {
match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
Ok(mut ident) => {
ident.set_span(self.span);
Ok(ident)
}
Err(_e) => Err(syn::Error::new(
self.span(),
format!(
"Expected valid identifier, found `{0}`. \
Diesel automatically renames invalid identifiers, \
perhaps you meant to write `{0}_`?",
self.field_name
),
))
} else {
Ok(())
)),
}
}
}

impl ToTokens for SqlIdentifier {
fn to_tokens(&self, tokens: &mut TokenStream) {
Ident::new(&self.field_name, self.span).to_tokens(tokens)
if self.field_name.starts_with("r#") {
Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
} else {
Ident::new(&self.field_name, self.span).to_tokens(tokens)
}
}
}

impl Display for SqlIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.field_name)
let mut start = 0;
if self.field_name.starts_with("r#") {
start = 2;
}
f.write_str(&self.field_name[start..])
}
}

Expand All @@ -93,6 +103,8 @@ impl PartialEq<Ident> for SqlIdentifier {

impl From<&'_ Ident> for SqlIdentifier {
fn from(ident: &'_ Ident) -> Self {
use syn::ext::IdentExt;
let ident = ident.unraw();
Self {
span: ident.span(),
field_name: ident.to_string(),
Expand Down
12 changes: 4 additions & 8 deletions diesel_derives/src/insertable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ fn field_ty_serialize_as(
ty: &Type,
treat_none_as_default_value: bool,
) -> Result<TokenStream> {
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
let span = field.span;
if treat_none_as_default_value {
let inner_ty = inner_of_option_ty(ty);
Expand Down Expand Up @@ -223,8 +222,7 @@ fn field_expr_serialize_as(
treat_none_as_default_value: bool,
) -> Result<TokenStream> {
let field_name = &field.name;
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
let column = quote!(#table_name::#column_name);
if treat_none_as_default_value {
if is_option_ty(ty) {
Expand All @@ -245,8 +243,7 @@ fn field_ty(
lifetime: Option<TokenStream>,
treat_none_as_default_value: bool,
) -> Result<TokenStream> {
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;
let span = field.span;
if treat_none_as_default_value {
let inner_ty = inner_of_option_ty(&field.ty);
Expand Down Expand Up @@ -276,8 +273,7 @@ fn field_expr(
treat_none_as_default_value: bool,
) -> Result<TokenStream> {
let field_name = &field.name;
let column_name = field.column_name()?;
column_name.valid_ident()?;
let column_name = field.column_name()?.to_ident()?;

let column: Expr = parse_quote!(#table_name::#column_name);
if treat_none_as_default_value {
Expand Down
2 changes: 1 addition & 1 deletion diesel_derives/src/queryable_by_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn sql_type(field: &Field, model: &Model) -> Result<Type> {
match field.sql_type {
Some(AttributeSpanWrapper { item: ref st, .. }) => Ok(st.clone()),
None => {
let column_name = field.column_name()?;
let column_name = field.column_name()?.to_ident()?;
Ok(parse_quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>))
}
}
Expand Down
4 changes: 2 additions & 2 deletions diesel_derives/src/selectable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn field_column_ty(
Ok(quote!(<#embed_ty as Selectable<__DB>>::SelectExpression))
} else {
let table_name = &model.table_names()[0];
let column_name = field.column_name()?;
let column_name = field.column_name()?.to_ident()?;
Ok(quote!(#table_name::#column_name))
}
}
Expand All @@ -165,7 +165,7 @@ fn field_column_inst(field: &Field, model: &Model) -> Result<TokenStream> {
Ok(quote!(<#embed_ty as Selectable<__DB>>::construct_selection()))
} else {
let table_name = &model.table_names()[0];
let column_name = field.column_name()?;
let column_name = field.column_name()?.to_ident()?;
Ok(quote!(#table_name::#column_name))
}
}
Loading

0 comments on commit 113b6b5

Please sign in to comment.