diff --git a/src/entity/prelude.rs b/src/entity/prelude.rs index 1d68cba14..d724551e1 100644 --- a/src/entity/prelude.rs +++ b/src/entity/prelude.rs @@ -1,8 +1,9 @@ pub use crate::{ error::*, ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType, CursorTrait, DatabaseConnection, DbConn, EntityName, EntityTrait, EnumIter, - ForeignKeyAction, Iden, IdenStatic, Linked, ModelTrait, PaginatorTrait, PrimaryKeyToColumn, - PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, RelationTrait, Select, Value, + ForeignKeyAction, Iden, IdenStatic, Linked, LoaderTrait, ModelTrait, PaginatorTrait, + PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult, Related, RelationDef, + RelationTrait, Select, Value, }; #[cfg(feature = "macros")] diff --git a/src/query/loader.rs b/src/query/loader.rs new file mode 100644 index 000000000..68ac1d4a4 --- /dev/null +++ b/src/query/loader.rs @@ -0,0 +1,339 @@ +use crate::{ + ColumnTrait, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter, + Related, RelationType, Select, +}; +use async_trait::async_trait; +use sea_query::{Expr, IntoColumnRef, SimpleExpr, ValueTuple}; +use std::{collections::HashMap, str::FromStr}; + +/// A trait for basic Dataloader +#[async_trait] +pub trait LoaderTrait { + /// Source model + type Model: ModelTrait; + + /// Used to eager load has_one relations + async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related; + + /// Used to eager load has_many relations + async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related; +} + +#[async_trait] +impl LoaderTrait for Vec +where + M: ModelTrait, + Vec: Sync, +{ + type Model = M; + + async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + + // we verify that is has_one relation + match (rel_def).rel_type { + RelationType::HasOne => (), + RelationType::HasMany => { + return Err(DbErr::Type("Relation is HasMany instead of HasOne".into())) + } + } + + let keys: Vec = self + .iter() + .map(|model: &M| extract_key(&rel_def.from_col, model)) + .collect(); + + let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt, condition); + + let data = stmt.all(db).await?; + + let hashmap: HashMap::Model> = data.into_iter().fold( + HashMap::::Model>::new(), + |mut acc: HashMap::Model>, + value: ::Model| { + { + let key = extract_key(&rel_def.to_col, &value); + + acc.insert(format!("{:?}", key), value); + } + + acc + }, + ); + + let result: Vec::Model>> = keys + .iter() + .map(|key| hashmap.get(&format!("{:?}", key)).cloned()) + .collect(); + + Ok(result) + } + + async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + + // we verify that is has_many relation + match (rel_def).rel_type { + RelationType::HasMany => (), + RelationType::HasOne => { + return Err(DbErr::Type("Relation is HasOne instead of HasMany".into())) + } + } + + let keys: Vec = self + .iter() + .map(|model: &M| extract_key(&rel_def.from_col, model)) + .collect(); + + let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt, condition); + + let data = stmt.all(db).await?; + + let mut hashmap: HashMap::Model>> = + keys.iter() + .fold(HashMap::new(), |mut acc, key: &ValueTuple| { + acc.insert(format!("{:?}", key), Vec::new()); + + acc + }); + + data.into_iter() + .for_each(|value: ::Model| { + let key = extract_key(&rel_def.to_col, &value); + + let vec = hashmap + .get_mut(&format!("{:?}", key)) + .expect("Failed at finding key on hashmap"); + + vec.push(value); + }); + + let result: Vec> = keys + .iter() + .map(|key: &ValueTuple| { + hashmap + .get(&format!("{:?}", key)) + .cloned() + .unwrap_or_default() + }) + .collect(); + + Ok(result) + } +} + +fn extract_key(target_col: &Identity, model: &Model) -> ValueTuple +where + Model: ModelTrait, +{ + match target_col { + Identity::Unary(a) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column A:1")); + ValueTuple::One(model.get(column_a)) + } + Identity::Binary(a, b) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column A:2")); + let column_b = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &b.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column B:2")); + ValueTuple::Two(model.get(column_a), model.get(column_b)) + } + Identity::Ternary(a, b, c) => { + let column_a = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &a.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column A:3")); + let column_b = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &b.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column B:3")); + let column_c = + <<::Entity as EntityTrait>::Column as FromStr>::from_str( + &c.to_string(), + ) + .unwrap_or_else(|_| panic!("Failed at mapping string to column C:3")); + ValueTuple::Three( + model.get(column_a), + model.get(column_b), + model.get(column_c), + ) + } + } +} + +fn prepare_condition(col: &Identity, keys: &[ValueTuple]) -> Condition +where + M: ModelTrait, +{ + match col { + Identity::Unary(column_a) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:1")); + Condition::all().add(ColumnTrait::is_in( + &column_a, + keys.iter().cloned().flatten(), + )) + } + Identity::Binary(column_a, column_b) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:2")); + let column_b: ::Column = + <::Column as FromStr>::from_str(&column_b.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:2")); + Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(column_a.into_column_ref()), + SimpleExpr::Column(column_b.into_column_ref()), + ]) + .in_tuples(keys.iter().cloned()), + ) + } + Identity::Ternary(column_a, column_b, column_c) => { + let column_a: ::Column = + <::Column as FromStr>::from_str(&column_a.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:3")); + let column_b: ::Column = + <::Column as FromStr>::from_str(&column_b.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:3")); + let column_c: ::Column = + <::Column as FromStr>::from_str(&column_c.to_string()) + .unwrap_or_else(|_| panic!("Failed at mapping string to column *C:3")); + Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(column_a.into_column_ref()), + SimpleExpr::Column(column_b.into_column_ref()), + SimpleExpr::Column(column_c.into_column_ref()), + ]) + .in_tuples(keys.iter().cloned()), + ) + } + } +} + +#[cfg(test)] +mod tests { + #[tokio::test] + async fn test_load_one() { + use crate::{ + entity::prelude::*, tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![ + cake::Model { + id: 1, + name: "New York Cheese".to_owned(), + } + .into_mock_row(), + cake::Model { + id: 2, + name: "London Cheese".to_owned(), + } + .into_mock_row(), + ]]) + .into_connection(); + + let fruits = vec![fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }]; + + let cakes = fruits + .load_one(cake::Entity::find(), &db) + .await + .expect("Should return something"); + + assert_eq!( + cakes, + vec![Some(cake::Model { + id: 1, + name: "New York Cheese".to_owned(), + })] + ); + } + + #[tokio::test] + async fn test_load_many() { + use crate::{ + entity::prelude::*, tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase, + }; + + let db = MockDatabase::new(DbBackend::Postgres) + .append_query_results(vec![vec![fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + } + .into_mock_row()]]) + .into_connection(); + + let cakes = vec![ + cake::Model { + id: 1, + name: "New York Cheese".to_owned(), + }, + cake::Model { + id: 2, + name: "London Cheese".to_owned(), + }, + ]; + + let fruits = cakes + .load_many(fruit::Entity::find(), &db) + .await + .expect("Should return something"); + + assert_eq!( + fruits, + vec![ + vec![fruit::Model { + id: 1, + name: "Apple".to_owned(), + cake_id: Some(1), + }], + vec![] + ] + ); + } +} diff --git a/src/query/mod.rs b/src/query/mod.rs index 2de0e7908..559eba176 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -5,6 +5,7 @@ mod insert; mod join; #[cfg(feature = "with-json")] mod json; +mod loader; mod select; mod traits; mod update; @@ -17,6 +18,7 @@ pub use insert::*; pub use join::*; #[cfg(feature = "with-json")] pub use json::*; +pub use loader::*; pub use select::*; pub use traits::*; pub use update::*; diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs new file mode 100644 index 000000000..3d38922a2 --- /dev/null +++ b/tests/loader_tests.rs @@ -0,0 +1,232 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +pub use sea_orm::{entity::*, query::*, DbErr, FromQueryResult}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_one() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_one").await; + create_tables(&ctx.db).await?; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert bakery"); + + let baker_1 = baker::ActiveModel { + name: Set("Baker 1".to_owned()), + contact_details: Set(serde_json::json!({ + "mobile": "+61424000000", + "home": "0395555555", + "address": "12 Test St, Testville, Vic, Australia" + })), + bakery_id: Set(Some(bakery.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let baker_2 = baker::ActiveModel { + name: Set("Baker 2".to_owned()), + contact_details: Set(serde_json::json!({})), + bakery_id: Set(None), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let bakers = baker::Entity::find() + .all(&ctx.db) + .await + .expect("Should load bakers"); + + let bakeries = bakers + .load_one(bakery::Entity::find(), &ctx.db) + .await + .expect("Should load bakeries"); + + assert_eq!(bakers, vec![baker_1, baker_2]); + + assert_eq!(bakeries, vec![Some(bakery), None]); + + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_one_complex() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_one").await; + create_tables(&ctx.db).await?; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert bakery"); + + let baker_1 = baker::ActiveModel { + name: Set("Baker 1".to_owned()), + contact_details: Set(serde_json::json!({ + "mobile": "+61424000000", + "home": "0395555555", + "address": "12 Test St, Testville, Vic, Australia" + })), + bakery_id: Set(Some(bakery.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let baker_2 = baker::ActiveModel { + name: Set("Baker 2".to_owned()), + contact_details: Set(serde_json::json!({})), + bakery_id: Set(Some(bakery.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let bakers = baker::Entity::find() + .all(&ctx.db) + .await + .expect("Should load bakers"); + + let bakeries = bakers + .load_one(bakery::Entity::find(), &ctx.db) + .await + .expect("Should load bakeries"); + + assert_eq!(bakers, vec![baker_1, baker_2]); + + assert_eq!(bakeries, vec![Some(bakery.clone()), Some(bakery.clone())]); + + Ok(()) +} + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_many() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many").await; + create_tables(&ctx.db).await?; + + let bakery_1 = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert bakery"); + + let bakery_2 = bakery::ActiveModel { + name: Set("Offshore Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert bakery"); + + let baker_1 = baker::ActiveModel { + name: Set("Baker 1".to_owned()), + contact_details: Set(serde_json::json!({ + "mobile": "+61424000000", + "home": "0395555555", + "address": "12 Test St, Testville, Vic, Australia" + })), + bakery_id: Set(Some(bakery_1.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let baker_2 = baker::ActiveModel { + name: Set("Baker 2".to_owned()), + contact_details: Set(serde_json::json!({})), + bakery_id: Set(Some(bakery_1.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let baker_3 = baker::ActiveModel { + name: Set("John".to_owned()), + contact_details: Set(serde_json::json!({})), + bakery_id: Set(Some(bakery_2.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let baker_4 = baker::ActiveModel { + name: Set("Baker 4".to_owned()), + contact_details: Set(serde_json::json!({})), + bakery_id: Set(Some(bakery_2.id)), + ..Default::default() + } + .insert(&ctx.db) + .await + .expect("could not insert baker"); + + let bakeries = bakery::Entity::find() + .all(&ctx.db) + .await + .expect("Should load bakeries"); + + let bakers = bakeries + .load_many( + baker::Entity::find().filter(baker::Column::Name.like("Baker%")), + &ctx.db, + ) + .await + .expect("Should load bakers"); + + println!("A: {:?}", bakers); + println!("B: {:?}", bakeries); + + assert_eq!(bakeries, vec![bakery_1, bakery_2]); + + assert_eq!( + bakers, + vec![ + vec![baker_1.clone(), baker_2.clone()], + vec![baker_4.clone()] + ] + ); + + let bakers = bakeries + .load_many(baker::Entity::find(), &ctx.db) + .await + .expect("Should load bakers"); + + assert_eq!(bakers, vec![vec![baker_1, baker_2], vec![baker_3, baker_4]]); + + Ok(()) +}