From 06708271489d67a8de362413e5b89f1110063df6 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 2 Feb 2023 07:46:38 +0800 Subject: [PATCH 1/3] Improve test cases --- src/entity/relation.rs | 2 +- src/query/loader.rs | 104 +++++++++++++++-------------------- tests/loader_tests.rs | 119 +++++++++++++++++++++++++++++------------ 3 files changed, 130 insertions(+), 95 deletions(-) diff --git a/src/entity/relation.rs b/src/entity/relation.rs index ad291f0f7..cae517888 100644 --- a/src/entity/relation.rs +++ b/src/entity/relation.rs @@ -7,7 +7,7 @@ use sea_query::{ use std::fmt::Debug; /// Defines the type of relationship -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum RelationType { /// An Entity has one relationship HasOne, diff --git a/src/query/loader.rs b/src/query/loader.rs index 5d9dd37dc..5beee927a 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -1,9 +1,9 @@ use crate::{ - error::*, ColumnTrait, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, - QueryFilter, Related, RelationType, Select, + error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter, + Related, RelationType, Select, }; use async_trait::async_trait; -use sea_query::{Expr, IntoColumnRef, SimpleExpr, ValueTuple}; +use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple}; use std::{collections::HashMap, str::FromStr}; /// A trait for basic Dataloader @@ -71,12 +71,13 @@ where R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, { + // we verify that is HasOne relation + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(type_err("Relation is ManytoMany instead of HasOne")); + } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); - - // we verify that is has_one relation - match (rel_def).rel_type { - RelationType::HasOne => (), - RelationType::HasMany => return Err(type_err("Relation is HasMany instead of HasOne")), + if rel_def.rel_type == RelationType::HasMany { + return Err(type_err("Relation is HasMany instead of HasOne")); } let keys: Vec = self @@ -84,7 +85,7 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); let stmt = as QueryFilter>::filter(stmt, condition); @@ -119,12 +120,14 @@ where R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, { - let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + // we verify that is HasMany relation - // we verify that is has_many relation - match (rel_def).rel_type { - RelationType::HasMany => (), - RelationType::HasOne => return Err(type_err("Relation is HasOne instead of HasMany")), + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(type_err("Relation is ManyToMany instead of HasMany")); + } + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type == RelationType::HasOne { + return Err(type_err("Relation is HasOne instead of HasMany")); } let keys: Vec = self @@ -132,7 +135,7 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); let stmt = as QueryFilter>::filter(stmt, condition); @@ -222,54 +225,35 @@ where } } -fn prepare_condition(col: &Identity, keys: &[ValueTuple]) -> Condition -where - M: ModelTrait, -{ +fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition { match col { Identity::Unary(column_a) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:1")); - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().cloned().flatten(), - )) - } - Identity::Binary(column_a, column_b) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:2")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:2")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) - } - Identity::Ternary(column_a, column_b, column_c) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:3")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:3")); - let column_c: ::Column = - <::Column as FromStr>::from_str(&column_c.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *C:3")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - SimpleExpr::Column(column_c.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) + let column_a = table_column(table, column_a); + Condition::all().add(Expr::col(column_a).is_in(keys.iter().cloned().flatten())) } + Identity::Binary(column_a, column_b) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + ]) + .in_tuples(keys.iter().cloned()), + ), + Identity::Ternary(column_a, column_b, column_c) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + SimpleExpr::Column(table_column(table, column_c)), + ]) + .in_tuples(keys.iter().cloned()), + ), + } +} + +fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef { + match tbl.to_owned() { + TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(), + TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(), + val => unimplemented!("Unsupported TableRef {val:?}"), } } diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index 0026fc85c..ba22d6188 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -24,18 +24,10 @@ async fn loader_load_one() -> Result<(), DbErr> { ..Default::default() } .insert(&ctx.db) - .await - .expect("could not insert baker"); - - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); + .await?; - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2]); @@ -59,15 +51,8 @@ async fn loader_load_one_complex() -> Result<(), DbErr> { let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?; - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); - - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2]); @@ -95,23 +80,19 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_3 = insert_baker(&ctx.db, "John", bakery_2.id).await?; let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; - let bakeries = bakery::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakeries"); + let bakeries = bakery::Entity::find().all(&ctx.db).await?; let bakers = bakeries .load_many( baker::Entity::find().filter(baker::Column::Name.like("Baker%")), &ctx.db, ) - .await - .expect("Should load bakers"); + .await?; println!("A: {bakers:?}"); println!("B: {bakeries:?}"); - assert_eq!(bakeries, [bakery_1, bakery_2]); + assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); assert_eq!( bakers, @@ -121,12 +102,32 @@ async fn loader_load_many() -> Result<(), DbErr> { ] ); - let bakers = bakeries - .load_many(baker::Entity::find(), &ctx.db) - .await - .expect("Should load bakers"); + let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; + + assert_eq!( + bakers, + [ + [baker_1.clone(), baker_2.clone()], + [baker_3.clone(), baker_4.clone()] + ] + ); + + // now, start from baker + + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; - assert_eq!(bakers, [[baker_1, baker_2], [baker_3, baker_4]]); + // note that two bakers share the same bakery + assert_eq!(bakers, [baker_1, baker_2, baker_3, baker_4]); + assert_eq!( + bakeries, + [ + Some(bakery_1.clone()), + Some(bakery_1), + Some(bakery_2.clone()), + Some(bakery_2) + ] + ); Ok(()) } @@ -137,8 +138,8 @@ async fn loader_load_many() -> Result<(), DbErr> { feature = "sqlx-sqlite", feature = "sqlx-postgres" ))] -async fn loader_load_many_many() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_many_many").await; +async fn loader_load_many_multi() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_multi").await; create_tables(&ctx.db).await?; let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; @@ -167,6 +168,43 @@ async fn loader_load_many_many() -> Result<(), DbErr> { Ok(()) } +#[ignore] +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_many_to_many() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_to_many").await; + create_tables(&ctx.db).await?; + + let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + + let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; + let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; + + let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_1.id).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_1.id).await?; + + insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?; + insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?; + + let bakers = baker::Entity::find().all(&ctx.db).await?; + let cakes = bakers.load_many(cake::Entity::find(), &ctx.db).await?; + + println!("{bakers:?}"); + println!("{cakes:?}"); + + assert_eq!(bakers, [baker_1, baker_2]); + assert_eq!(cakes, [vec![cake_1, cake_2.clone()], vec![cake_2, cake_3]]); + + Ok(()) +} + pub async fn insert_bakery(db: &DbConn, name: &str) -> Result { bakery::ActiveModel { name: Set(name.to_owned()), @@ -199,3 +237,16 @@ pub async fn insert_cake(db: &DbConn, name: &str, bakery_id: i32) -> Result Result { + cakes_bakers::ActiveModel { + cake_id: Set(cake_id), + baker_id: Set(baker_id), + } + .insert(db) + .await +} From 83c0732395ab083e809e3e5d41f92d8a03551a65 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 2 Feb 2023 09:38:30 +0800 Subject: [PATCH 2/3] load_many_to_many --- src/query/loader.rs | 142 ++++++++++++++++++++++++++++++++++++++++-- tests/loader_tests.rs | 127 ++++++++++++++++++------------------- 2 files changed, 198 insertions(+), 71 deletions(-) diff --git a/src/query/loader.rs b/src/query/loader.rs index 5beee927a..56f263de5 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -20,13 +20,28 @@ pub trait LoaderTrait { R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; - /// Used to eager load has_many relations + /// Used to eager load has_many relations async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; + + /// Used to eager load many_to_many relations + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related; } #[async_trait] @@ -55,6 +70,23 @@ where { self.as_slice().load_many(stmt, db).await } + + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + self.as_slice().load_many_to_many(stmt, via, db).await + } } #[async_trait] @@ -73,11 +105,11 @@ where { // we verify that is HasOne relation if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { - return Err(type_err("Relation is ManytoMany instead of HasOne")); + return Err(query_err("Relation is ManytoMany instead of HasOne")); } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); if rel_def.rel_type == RelationType::HasMany { - return Err(type_err("Relation is HasMany instead of HasOne")); + return Err(query_err("Relation is HasMany instead of HasOne")); } let keys: Vec = self @@ -123,11 +155,11 @@ where // we verify that is HasMany relation if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { - return Err(type_err("Relation is ManyToMany instead of HasMany")); + return Err(query_err("Relation is ManyToMany instead of HasMany")); } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); if rel_def.rel_type == RelationType::HasOne { - return Err(type_err("Relation is HasOne instead of HasMany")); + return Err(query_err("Relation is HasOne instead of HasMany")); } let keys: Vec = self @@ -172,6 +204,106 @@ where Ok(result) } + + async fn load_many_to_many( + &self, + stmt: Select, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + if let Some(via_rel) = + <<::Model as ModelTrait>::Entity as Related>::via() + { + let rel_def = + <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type != RelationType::HasOne { + return Err(query_err("Relation to is not HasOne")); + } + + if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) { + return Err(query_err(format!( + "The given via Entity is incorrect: expected: {:?}, given: {:?}", + via_rel.to_tbl, + via.table_ref() + ))); + } + + let pkeys: Vec = self + .iter() + .map(|model: &M| extract_key(&via_rel.from_col, model)) + .collect(); + + // Map of M::PK -> Vec + let mut keymap: HashMap> = Default::default(); + + let keys: Vec = { + let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys); + let stmt = V::find().filter(condition); + let data = stmt.all(db).await?; + data.into_iter().for_each(|model| { + let pk = format!("{:?}", extract_key(&via_rel.to_col, &model)); + let entry = keymap.entry(pk).or_default(); + + let fk = extract_key(&rel_def.from_col, &model); + entry.push(fk); + }); + + keymap.values().flatten().cloned().collect() + }; + + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt, condition); + + let data = stmt.all(db).await?; + // Map of R::PK -> R::Model + let data: HashMap::Model> = data + .into_iter() + .map(|model| { + let key = format!("{:?}", extract_key(&rel_def.to_col, &model)); + (key, model) + }) + .collect(); + + let result: Vec> = pkeys + .into_iter() + .map(|pkey| { + let fkeys = keymap + .get(&format!("{pkey:?}")) + .cloned() + .unwrap_or_default(); + + let models: Vec<_> = fkeys + .into_iter() + .map(|fkey| { + data.get(&format!("{fkey:?}")) + .cloned() + .expect("Failed at finding key on hashmap") + }) + .collect(); + + models + }) + .collect(); + + Ok(result) + } else { + return Err(query_err("Relation is not ManyToMany")); + } + } +} + +fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool { + // not ideal; but + format!("{left:?}") == format!("{right:?}") } fn extract_key(target_col: &Identity, model: &Model) -> ValueTuple diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index ba22d6188..bc0b5467a 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -13,12 +13,12 @@ async fn loader_load_one() -> Result<(), DbErr> { let ctx = TestContext::new("loader_test_load_one").await; create_tables(&ctx.db).await?; - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + let bakery_0 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - - let baker_2 = baker::ActiveModel { - name: Set("Baker 2".to_owned()), + let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery_0.id).await?; + let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery_0.id).await?; + let baker_3 = baker::ActiveModel { + name: Set("Baker 3".to_owned()), contact_details: Set(serde_json::json!({})), bakery_id: Set(None), ..Default::default() @@ -29,34 +29,8 @@ async fn loader_load_one() -> Result<(), DbErr> { let bakers = baker::Entity::find().all(&ctx.db).await?; let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery), None]); - - Ok(()) -} - -#[sea_orm_macros::test] -#[cfg(any( - feature = "sqlx-mysql", - feature = "sqlx-sqlite", - feature = "sqlx-postgres" -))] -async fn loader_load_one_complex() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_one_complex").await; - create_tables(&ctx.db).await?; - - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?; - - let bakers = baker::Entity::find().all(&ctx.db).await?; - let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; - - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery.clone()), Some(bakery.clone())]); + assert_eq!(bakers, [baker_1, baker_2, baker_3]); + assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]); Ok(()) } @@ -81,6 +55,18 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; let bakeries = bakery::Entity::find().all(&ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; + + assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); + assert_eq!( + bakers, + [ + [baker_1.clone(), baker_2.clone()], + [baker_3.clone(), baker_4.clone()] + ] + ); + + // load bakers again but with additional condition let bakers = bakeries .load_many( @@ -89,11 +75,6 @@ async fn loader_load_many() -> Result<(), DbErr> { ) .await?; - println!("A: {bakers:?}"); - println!("B: {bakeries:?}"); - - assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); - assert_eq!( bakers, [ @@ -102,16 +83,6 @@ async fn loader_load_many() -> Result<(), DbErr> { ] ); - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - - assert_eq!( - bakers, - [ - [baker_1.clone(), baker_2.clone()], - [baker_3.clone(), baker_4.clone()] - ] - ); - // now, start from baker let bakers = baker::Entity::find().all(&ctx.db).await?; @@ -149,18 +120,15 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { let baker_2 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; let baker_3 = insert_baker(&ctx.db, "Peter", bakery_2.id).await?; - let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_2.id).await?; - let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_2.id).await?; + let cake_1 = insert_cake(&ctx.db, "Cheesecake", Some(bakery_1.id)).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", Some(bakery_2.id)).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", Some(bakery_2.id)).await?; + let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie let bakeries = bakery::Entity::find().all(&ctx.db).await?; let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?; - println!("{bakers:?}"); - println!("{bakeries:?}"); - println!("{cakes:?}"); - assert_eq!(bakeries, [bakery_1, bakery_2]); assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]); assert_eq!(cakes, [vec![cake_1], vec![cake_2, cake_3]]); @@ -168,7 +136,6 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { Ok(()) } -#[ignore] #[sea_orm_macros::test] #[cfg(any( feature = "sqlx-mysql", @@ -184,9 +151,10 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; - let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_1.id).await?; - let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_1.id).await?; + let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?; + let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?; insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?; @@ -194,13 +162,36 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?; let bakers = baker::Entity::find().all(&ctx.db).await?; - let cakes = bakers.load_many(cake::Entity::find(), &ctx.db).await?; + let cakes = bakers + .load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db) + .await?; - println!("{bakers:?}"); - println!("{cakes:?}"); + assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]); + assert_eq!( + cakes, + [ + vec![cake_1.clone(), cake_2.clone()], + vec![cake_2.clone(), cake_3.clone()] + ] + ); - assert_eq!(bakers, [baker_1, baker_2]); - assert_eq!(cakes, [vec![cake_1, cake_2.clone()], vec![cake_2, cake_3]]); + // now, start again from cakes + + let cakes = cake::Entity::find().all(&ctx.db).await?; + let bakers = cakes + .load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db) + .await?; + + assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]); + assert_eq!( + bakers, + [ + vec![baker_1.clone()], + vec![baker_1.clone(), baker_2.clone()], + vec![baker_2.clone()], + vec![] + ] + ); Ok(()) } @@ -226,12 +217,16 @@ pub async fn insert_baker(db: &DbConn, name: &str, bakery_id: i32) -> Result Result { +pub async fn insert_cake( + db: &DbConn, + name: &str, + bakery_id: Option, +) -> Result { cake::ActiveModel { name: Set(name.to_owned()), price: Set(rust_decimal::Decimal::ONE), gluten_free: Set(false), - bakery_id: Set(Some(bakery_id)), + bakery_id: Set(bakery_id), ..Default::default() } .insert(db) From 56e4b4337b4298f718d93043f203acceae574392 Mon Sep 17 00:00:00 2001 From: Chris Tsang Date: Thu, 2 Feb 2023 11:21:00 +0800 Subject: [PATCH 3/3] Improve API & Example --- examples/basic/src/select.rs | 45 +++++++++++++++++++++-- src/query/loader.rs | 71 +++++++++++++++++++++++++----------- tests/loader_tests.rs | 25 +++++++++---- 3 files changed, 110 insertions(+), 31 deletions(-) diff --git a/examples/basic/src/select.rs b/examples/basic/src/select.rs index 220920ce3..611226a21 100644 --- a/examples/basic/src/select.rs +++ b/examples/basic/src/select.rs @@ -10,6 +10,10 @@ pub async fn all_about_select(db: &DbConn) -> Result<(), DbErr> { println!("===== =====\n"); + find_many(db).await?; + + println!("===== =====\n"); + find_one(db).await?; println!("===== =====\n"); @@ -77,6 +81,30 @@ async fn find_together(db: &DbConn) -> Result<(), DbErr> { Ok(()) } +async fn find_many(db: &DbConn) -> Result<(), DbErr> { + print!("find cakes with fruits: "); + + let cakes_with_fruits: Vec<(cake::Model, Vec)> = Cake::find() + .find_with_related(fruit::Entity) + .all(db) + .await?; + + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fruits: Vec> = cakes.load_many(fruit::Entity, db).await?; + + println!(); + for (left, right) in cakes_with_fruits + .into_iter() + .zip(cakes.into_iter().zip(fruits.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); + } + + Ok(()) +} + impl Cake { fn find_by_name(name: &str) -> Select { Self::find().filter(cake::Column::Name.contains(name)) @@ -142,13 +170,24 @@ async fn count_fruits_by_cake(db: &DbConn) -> Result<(), DbErr> { async fn find_many_to_many(db: &DbConn) -> Result<(), DbErr> { print!("find cakes and fillings: "); - let both: Vec<(cake::Model, Vec)> = + let cakes_with_fillings: Vec<(cake::Model, Vec)> = Cake::find().find_with_related(Filling).all(db).await?; + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fillings: Vec> = cakes + .load_many_to_many(filling::Entity, cake_filling::Entity, db) + .await?; + println!(); - for bb in both.iter() { - println!("{bb:?}\n"); + for (left, right) in cakes_with_fillings + .into_iter() + .zip(cakes.into_iter().zip(fillings.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); } + println!(); print!("find fillings for cheese cake: "); diff --git a/src/query/loader.rs b/src/query/loader.rs index 56f263de5..90e5e35bc 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -6,32 +6,40 @@ use async_trait::async_trait; use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple}; use std::{collections::HashMap, str::FromStr}; -/// A trait for basic Dataloader +/// Entity, or a Select; to be used as parameters in [`LoaderTrait`] +pub trait EntityOrSelect: Send { + /// If self is Entity, use Entity::find() + fn select(self) -> Select; +} + +/// This trait implements the Data Loader API #[async_trait] pub trait LoaderTrait { /// Source model type Model: ModelTrait; /// Used to eager load has_one relations - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related; /// Used to eager load has_many relations - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related; /// Used to eager load many_to_many relations - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -39,11 +47,30 @@ pub trait LoaderTrait { C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; } +impl EntityOrSelect for E +where + E: EntityTrait, +{ + fn select(self) -> Select { + E::find() + } +} + +impl EntityOrSelect for Select +where + E: EntityTrait, +{ + fn select(self) -> Select { + self + } +} + #[async_trait] impl LoaderTrait for Vec where @@ -51,29 +78,31 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_one(stmt, db).await } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_many(stmt, db).await } - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -81,6 +110,7 @@ where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, @@ -96,11 +126,12 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { // we verify that is HasOne relation @@ -119,7 +150,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -145,11 +176,12 @@ where Ok(result) } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { // we verify that is HasMany relation @@ -169,7 +201,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -205,9 +237,9 @@ where Ok(result) } - async fn load_many_to_many( + async fn load_many_to_many( &self, - stmt: Select, + stmt: S, via: V, db: &C, ) -> Result>, DbErr> @@ -215,6 +247,7 @@ where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, V: EntityTrait, V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related, @@ -261,7 +294,7 @@ where let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; // Map of R::PK -> R::Model @@ -283,11 +316,7 @@ where let models: Vec<_> = fkeys .into_iter() - .map(|fkey| { - data.get(&format!("{fkey:?}")) - .cloned() - .expect("Failed at finding key on hashmap") - }) + .filter_map(|fkey| data.get(&format!("{fkey:?}")).cloned()) .collect(); models diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index bc0b5467a..2db148a2c 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -27,7 +27,7 @@ async fn loader_load_one() -> Result<(), DbErr> { .await?; let bakers = baker::Entity::find().all(&ctx.db).await?; - let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity, &ctx.db).await?; assert_eq!(bakers, [baker_1, baker_2, baker_3]); assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]); @@ -55,7 +55,7 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; let bakeries = bakery::Entity::find().all(&ctx.db).await?; - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); assert_eq!( @@ -126,8 +126,8 @@ async fn loader_load_many_multi() -> Result<(), DbErr> { let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie let bakeries = bakery::Entity::find().all(&ctx.db).await?; - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; + let cakes = bakeries.load_many(cake::Entity, &ctx.db).await?; assert_eq!(bakeries, [bakery_1, bakery_2]); assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]); @@ -152,7 +152,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", None).await?; + let cake_2 = insert_cake(&ctx.db, "Coffee", None).await?; let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?; let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie @@ -163,7 +163,7 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { let bakers = baker::Entity::find().all(&ctx.db).await?; let cakes = bakers - .load_many_to_many(cake::Entity::find(), cakes_bakers::Entity, &ctx.db) + .load_many_to_many(cake::Entity, cakes_bakers::Entity, &ctx.db) .await?; assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]); @@ -175,11 +175,22 @@ async fn loader_load_many_to_many() -> Result<(), DbErr> { ] ); + // same, but apply restrictions on cakes + + let cakes = bakers + .load_many_to_many( + cake::Entity::find().filter(cake::Column::Name.like("Ch%")), + cakes_bakers::Entity, + &ctx.db, + ) + .await?; + assert_eq!(cakes, [vec![cake_1.clone()], vec![cake_3.clone()]]); + // now, start again from cakes let cakes = cake::Entity::find().all(&ctx.db).await?; let bakers = cakes - .load_many_to_many(baker::Entity::find(), cakes_bakers::Entity, &ctx.db) + .load_many_to_many(baker::Entity, cakes_bakers::Entity, &ctx.db) .await?; assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]);