From e9298dfed96fa28fd3d242fb4491c35a2811c2e7 Mon Sep 17 00:00:00 2001 From: Billy Chan Date: Sun, 12 Dec 2021 23:10:35 +0800 Subject: [PATCH] codegen ActiveEnum & create enum from ActiveEnum --- sea-orm-cli/Cargo.toml | 2 +- sea-orm-codegen/Cargo.toml | 2 +- sea-orm-codegen/src/entity/active_enum.rs | 31 +++++++++++ sea-orm-codegen/src/entity/column.rs | 35 ++++++------ sea-orm-codegen/src/entity/mod.rs | 2 + sea-orm-codegen/src/entity/transformer.rs | 27 +++++++++- sea-orm-codegen/src/entity/writer.rs | 65 ++++++++++++++++++++--- src/schema/entity.rs | 45 ++++++++++++---- tests/common/features/schema.rs | 2 +- 9 files changed, 173 insertions(+), 38 deletions(-) create mode 100644 sea-orm-codegen/src/entity/active_enum.rs diff --git a/sea-orm-cli/Cargo.toml b/sea-orm-cli/Cargo.toml index 6f67ccd64..0c8a06605 100644 --- a/sea-orm-cli/Cargo.toml +++ b/sea-orm-cli/Cargo.toml @@ -22,7 +22,7 @@ clap = { version = "^2.33.3" } dotenv = { version = "^0.15" } async-std = { version = "^1.9", features = [ "attributes" ] } sea-orm-codegen = { version = "^0.4.2", path = "../sea-orm-codegen" } -sea-schema = { version = "^0.2.9", default-features = false, features = [ +sea-schema = { version = "0.3.0", default-features = false, features = [ "debug-print", "sqlx-mysql", "sqlx-postgres", diff --git a/sea-orm-codegen/Cargo.toml b/sea-orm-codegen/Cargo.toml index a4b0a57b8..2f2cfeaf7 100644 --- a/sea-orm-codegen/Cargo.toml +++ b/sea-orm-codegen/Cargo.toml @@ -15,7 +15,7 @@ name = "sea_orm_codegen" path = "src/lib.rs" [dependencies] -sea-query = { version = "^0.16.4" } +sea-query = { version = "0.20.0" } syn = { version = "^1", default-features = false, features = [ "derive", "parsing", diff --git a/sea-orm-codegen/src/entity/active_enum.rs b/sea-orm-codegen/src/entity/active_enum.rs new file mode 100644 index 000000000..d92423f25 --- /dev/null +++ b/sea-orm-codegen/src/entity/active_enum.rs @@ -0,0 +1,31 @@ +use heck::CamelCase; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +#[derive(Clone, Debug)] +pub struct ActiveEnum { + pub(crate) enum_name: String, + pub(crate) values: Vec, +} + +impl ActiveEnum { + pub fn impl_active_enum(&self) -> TokenStream { + let enum_name = &self.enum_name; + let enum_iden = format_ident!("{}", enum_name.to_camel_case()); + let values = &self.values; + let variants = self + .values + .iter() + .map(|v| format_ident!("{}", v.to_camel_case())); + quote! { + #[derive(Debug, Clone, PartialEq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = #enum_name)] + pub enum #enum_iden { + #( + #[sea_orm(string_value = #values)] + #variants, + )* + } + } + } +} diff --git a/sea-orm-codegen/src/entity/column.rs b/sea-orm-codegen/src/entity/column.rs index 39eb340cb..359046008 100644 --- a/sea-orm-codegen/src/entity/column.rs +++ b/sea-orm-codegen/src/entity/column.rs @@ -24,26 +24,27 @@ impl Column { pub fn get_rs_type(&self) -> TokenStream { #[allow(unreachable_patterns)] - let ident: TokenStream = match self.col_type { + let ident: TokenStream = match &self.col_type { ColumnType::Char(_) | ColumnType::String(_) | ColumnType::Text - | ColumnType::Custom(_) => "String", - ColumnType::TinyInteger(_) => "i8", - ColumnType::SmallInteger(_) => "i16", - ColumnType::Integer(_) => "i32", - ColumnType::BigInteger(_) => "i64", - ColumnType::Float(_) => "f32", - ColumnType::Double(_) => "f64", - ColumnType::Json | ColumnType::JsonBinary => "Json", - ColumnType::Date => "Date", - ColumnType::Time(_) => "Time", - ColumnType::DateTime(_) | ColumnType::Timestamp(_) => "DateTime", - ColumnType::TimestampWithTimeZone(_) => "DateTimeWithTimeZone", - ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal", - ColumnType::Uuid => "Uuid", - ColumnType::Binary(_) => "Vec", - ColumnType::Boolean => "bool", + | ColumnType::Custom(_) => "String".to_owned(), + ColumnType::TinyInteger(_) => "i8".to_owned(), + ColumnType::SmallInteger(_) => "i16".to_owned(), + ColumnType::Integer(_) => "i32".to_owned(), + ColumnType::BigInteger(_) => "i64".to_owned(), + ColumnType::Float(_) => "f32".to_owned(), + ColumnType::Double(_) => "f64".to_owned(), + ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(), + ColumnType::Date => "Date".to_owned(), + ColumnType::Time(_) => "Time".to_owned(), + ColumnType::DateTime(_) | ColumnType::Timestamp(_) => "DateTime".to_owned(), + ColumnType::TimestampWithTimeZone(_) => "DateTimeWithTimeZone".to_owned(), + ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(), + ColumnType::Uuid => "Uuid".to_owned(), + ColumnType::Binary(_) => "Vec".to_owned(), + ColumnType::Boolean => "bool".to_owned(), + ColumnType::Enum(name, _) => name.to_camel_case(), _ => unimplemented!(), } .parse() diff --git a/sea-orm-codegen/src/entity/mod.rs b/sea-orm-codegen/src/entity/mod.rs index ee0ee830c..cb4619044 100644 --- a/sea-orm-codegen/src/entity/mod.rs +++ b/sea-orm-codegen/src/entity/mod.rs @@ -1,3 +1,4 @@ +mod active_enum; mod base_entity; mod column; mod conjunct_relation; @@ -6,6 +7,7 @@ mod relation; mod transformer; mod writer; +pub use active_enum::*; pub use base_entity::*; pub use column::*; pub use conjunct_relation::*; diff --git a/sea-orm-codegen/src/entity/transformer.rs b/sea-orm-codegen/src/entity/transformer.rs index 9a998b4b9..8c9cdd1c1 100644 --- a/sea-orm-codegen/src/entity/transformer.rs +++ b/sea-orm-codegen/src/entity/transformer.rs @@ -1,5 +1,6 @@ use crate::{ - Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation, RelationType, + ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation, + RelationType, }; use sea_query::TableStatement; use std::collections::HashMap; @@ -9,6 +10,7 @@ pub struct EntityTransformer; impl EntityTransformer { pub fn transform(table_stmts: Vec) -> Result { + let mut enums: HashMap = HashMap::new(); let mut inverse_relations: HashMap> = HashMap::new(); let mut conjunct_relations: HashMap> = HashMap::new(); let mut entities = HashMap::new(); @@ -22,7 +24,15 @@ impl EntityTransformer { } }; let table_name = match table_create.get_table_name() { - Some(s) => s, + Some(table_ref) => match table_ref { + sea_query::TableRef::Table(t) + | sea_query::TableRef::SchemaTable(_, t) + | sea_query::TableRef::DatabaseSchemaTable(_, _, t) + | sea_query::TableRef::TableAlias(t, _) + | sea_query::TableRef::SchemaTableAlias(_, t, _) + | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => t.to_string(), + _ => unimplemented!(), + }, None => { return Err(Error::TransformError( "Table name should not be empty".into(), @@ -44,6 +54,18 @@ impl EntityTransformer { > 0; col }) + .map(|col| { + if let sea_query::ColumnType::Enum(enum_name, values) = &col.col_type { + enums.insert( + enum_name.clone(), + ActiveEnum { + enum_name: enum_name.clone(), + values: values.clone(), + }, + ); + } + col + }) .collect(); let mut ref_table_counts: HashMap = HashMap::new(); let relations: Vec = table_create @@ -170,6 +192,7 @@ impl EntityTransformer { } Ok(EntityWriter { entities: entities.into_iter().map(|(_, v)| v).collect(), + enums, }) } } diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index a3648da6e..c767623f8 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -1,13 +1,14 @@ -use std::str::FromStr; - -use crate::Entity; +use crate::{ActiveEnum, Entity}; +use heck::CamelCase; use proc_macro2::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; +use std::{collections::HashMap, str::FromStr}; use syn::{punctuated::Punctuated, token::Comma}; #[derive(Clone, Debug)] pub struct EntityWriter { pub(crate) entities: Vec, + pub(crate) enums: HashMap, } pub struct WriterOutput { @@ -83,6 +84,9 @@ impl EntityWriter { files.extend(self.write_entities(expanded_format, with_serde)); files.push(self.write_mod()); files.push(self.write_prelude()); + if !self.enums.is_empty() { + files.push(self.write_sea_orm_active_enums()); + } WriterOutput { files } } @@ -122,6 +126,14 @@ impl EntityWriter { ); lines.push("".to_owned()); Self::write(&mut lines, code_blocks); + if !self.enums.is_empty() { + Self::write( + &mut lines, + vec![quote! { + pub mod sea_orm_active_enums; + }], + ); + } OutputFile { name: "mod.rs".to_owned(), content: lines.join("\n"), @@ -143,6 +155,28 @@ impl EntityWriter { } } + pub fn write_sea_orm_active_enums(&self) -> OutputFile { + let mut lines = Vec::new(); + Self::write_doc_comment(&mut lines); + Self::write( + &mut lines, + vec![quote! { + use sea_orm::entity::prelude::*; + }], + ); + lines.push("".to_owned()); + let code_blocks = self + .enums + .iter() + .map(|(_, active_enum)| active_enum.impl_active_enum()) + .collect(); + Self::write(&mut lines, code_blocks); + OutputFile { + name: "sea_orm_active_enums.rs".to_owned(), + content: lines.join("\n"), + } + } + pub fn write(lines: &mut Vec, code_blocks: Vec) { lines.extend( code_blocks @@ -163,8 +197,10 @@ impl EntityWriter { } pub fn gen_expanded_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec { + let mut imports = Self::gen_import(with_serde); + imports.extend(Self::gen_import_active_enum(entity)); let mut code_blocks = vec![ - Self::gen_import(with_serde), + imports, Self::gen_entity_struct(), Self::gen_impl_entity_name(entity), Self::gen_model_struct(entity, with_serde), @@ -182,8 +218,10 @@ impl EntityWriter { } pub fn gen_compact_code_blocks(entity: &Entity, with_serde: &WithSerde) -> Vec { + let mut imports = Self::gen_import(with_serde); + imports.extend(Self::gen_import_active_enum(entity)); let mut code_blocks = vec![ - Self::gen_import(with_serde), + imports, Self::gen_compact_model_struct(entity, with_serde), ]; let relation_defs = if entity.get_relation_enum_name().is_empty() { @@ -249,6 +287,21 @@ impl EntityWriter { } } + pub fn gen_import_active_enum(entity: &Entity) -> TokenStream { + entity + .columns + .iter() + .fold(TokenStream::new(), |mut ts, col| { + if let sea_query::ColumnType::Enum(enum_name, _) = &col.col_type { + let enum_name = format_ident!("{}", enum_name.to_camel_case()); + ts.extend(vec![quote! { + use super::sea_orm_active_enums::#enum_name; + }]); + } + ts + }) + } + pub fn gen_model_struct(entity: &Entity, with_serde: &WithSerde) -> TokenStream { let column_names_snake_case = entity.get_column_names_snake_case(); let column_rs_types = entity.get_column_rs_types(); diff --git a/src/schema/entity.rs b/src/schema/entity.rs index cc185e69d..9c422ed34 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -1,6 +1,6 @@ use crate::{ - unpack_table_ref, ColumnTrait, ColumnType, DbBackend, EntityTrait, Identity, Iterable, - PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema, + unpack_table_ref, ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, Identity, + Iterable, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema, }; use sea_query::{ extension::postgres::{Type, TypeCreateStatement}, @@ -8,6 +8,14 @@ use sea_query::{ }; impl Schema { + /// Creates Postgres enums from an ActiveEnum. See [TypeCreateStatement] for more details + pub fn create_enum_from_active_enum(&self) -> TypeCreateStatement + where + A: ActiveEnum, + { + create_enum_from_active_enum::(self.backend) + } + /// Creates Postgres enums from an Entity. See [TypeCreateStatement] for more details pub fn create_enum_from_entity(&self, entity: E) -> Vec where @@ -25,6 +33,30 @@ impl Schema { } } +pub(crate) fn create_enum_from_active_enum(backend: DbBackend) -> TypeCreateStatement +where + A: ActiveEnum, +{ + if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) { + panic!("TypeCreateStatement is not supported in MySQL & SQLite"); + } + let col_def = A::db_type(); + let col_type = col_def.get_column_type(); + create_enum_from_column_type(col_type) +} + +pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> TypeCreateStatement { + let (name, values) = match col_type { + ColumnType::Enum(s, v) => (s.as_str(), v), + _ => panic!("Should be ColumnType::Enum"), + }; + Type::create() + .as_enum(Alias::new(name)) + .values(values.iter().map(|val| Alias::new(val.as_str()))) + .to_owned() +} + +#[allow(clippy::needless_borrow)] pub(crate) fn create_enum_from_entity(_: E, backend: DbBackend) -> Vec where E: EntityTrait, @@ -39,14 +71,7 @@ where if !matches!(col_type, ColumnType::Enum(_, _)) { continue; } - let (name, values) = match col_type { - ColumnType::Enum(s, v) => (s.as_str(), v), - _ => unreachable!(), - }; - let stmt = Type::create() - .as_enum(Alias::new(name)) - .values(values.iter().map(|val| Alias::new(val.as_str()))) - .to_owned(); + let stmt = create_enum_from_column_type(&col_type); vec.push(stmt); } vec diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 91712c917..684659647 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -4,7 +4,7 @@ use super::*; use crate::common::setup::{create_enum, create_table, create_table_without_asserts}; use sea_orm::{ error::*, sea_query, ConnectionTrait, DatabaseConnection, DbBackend, DbConn, EntityName, - ExecResult, + ExecResult, Schema, }; use sea_query::{extension::postgres::Type, Alias, ColumnDef, ForeignKeyCreateStatement};