diff --git a/examples/basic/src/select.rs b/examples/basic/src/select.rs index 220920ce3..611226a21 100644 --- a/examples/basic/src/select.rs +++ b/examples/basic/src/select.rs @@ -10,6 +10,10 @@ pub async fn all_about_select(db: &DbConn) -> Result<(), DbErr> { println!("===== =====\n"); + find_many(db).await?; + + println!("===== =====\n"); + find_one(db).await?; println!("===== =====\n"); @@ -77,6 +81,30 @@ async fn find_together(db: &DbConn) -> Result<(), DbErr> { Ok(()) } +async fn find_many(db: &DbConn) -> Result<(), DbErr> { + print!("find cakes with fruits: "); + + let cakes_with_fruits: Vec<(cake::Model, Vec)> = Cake::find() + .find_with_related(fruit::Entity) + .all(db) + .await?; + + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fruits: Vec> = cakes.load_many(fruit::Entity, db).await?; + + println!(); + for (left, right) in cakes_with_fruits + .into_iter() + .zip(cakes.into_iter().zip(fruits.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); + } + + Ok(()) +} + impl Cake { fn find_by_name(name: &str) -> Select { Self::find().filter(cake::Column::Name.contains(name)) @@ -142,13 +170,24 @@ async fn count_fruits_by_cake(db: &DbConn) -> Result<(), DbErr> { async fn find_many_to_many(db: &DbConn) -> Result<(), DbErr> { print!("find cakes and fillings: "); - let both: Vec<(cake::Model, Vec)> = + let cakes_with_fillings: Vec<(cake::Model, Vec)> = Cake::find().find_with_related(Filling).all(db).await?; + // equivalent; but with a different API + let cakes: Vec = Cake::find().all(db).await?; + let fillings: Vec> = cakes + .load_many_to_many(filling::Entity, cake_filling::Entity, db) + .await?; + println!(); - for bb in both.iter() { - println!("{bb:?}\n"); + for (left, right) in cakes_with_fillings + .into_iter() + .zip(cakes.into_iter().zip(fillings.into_iter())) + { + println!("{left:?}\n"); + assert_eq!(left, right); } + println!(); print!("find fillings for cheese cake: "); diff --git a/src/entity/relation.rs b/src/entity/relation.rs index ad291f0f7..cae517888 100644 --- a/src/entity/relation.rs +++ b/src/entity/relation.rs @@ -7,7 +7,7 @@ use sea_query::{ use std::fmt::Debug; /// Defines the type of relationship -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum RelationType { /// An Entity has one relationship HasOne, diff --git a/src/query/loader.rs b/src/query/loader.rs index 5d9dd37dc..90e5e35bc 100644 --- a/src/query/loader.rs +++ b/src/query/loader.rs @@ -1,34 +1,76 @@ use crate::{ - error::*, ColumnTrait, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, - QueryFilter, Related, RelationType, Select, + error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter, + Related, RelationType, Select, }; use async_trait::async_trait; -use sea_query::{Expr, IntoColumnRef, SimpleExpr, ValueTuple}; +use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple}; use std::{collections::HashMap, str::FromStr}; -/// A trait for basic Dataloader +/// Entity, or a Select; to be used as parameters in [`LoaderTrait`] +pub trait EntityOrSelect: Send { + /// If self is Entity, use Entity::find() + fn select(self) -> Select; +} + +/// This trait implements the Data Loader API #[async_trait] pub trait LoaderTrait { /// Source model type Model: ModelTrait; /// Used to eager load has_one relations - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + S: EntityOrSelect, + <::Model as ModelTrait>::Entity: Related; + + /// Used to eager load has_many relations + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related; - /// Used to eager load has_many relations - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + /// Used to eager load many_to_many relations + async fn load_many_to_many( + &self, + stmt: S, + via: V, + db: &C, + ) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, + V: EntityTrait, + V::Model: Send + Sync, <::Model as ModelTrait>::Entity: Related; } +impl EntityOrSelect for E +where + E: EntityTrait, +{ + fn select(self) -> Select { + E::find() + } +} + +impl EntityOrSelect for Select +where + E: EntityTrait, +{ + fn select(self) -> Select { + self + } +} + #[async_trait] impl LoaderTrait for Vec where @@ -36,25 +78,45 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_one(stmt, db).await } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { self.as_slice().load_many(stmt, db).await } + + async fn load_many_to_many( + &self, + stmt: S, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + S: EntityOrSelect, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + self.as_slice().load_many_to_many(stmt, via, db).await + } } #[async_trait] @@ -64,19 +126,21 @@ where { type Model = M; - async fn load_one(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_one(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { + // we verify that is HasOne relation + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(query_err("Relation is ManytoMany instead of HasOne")); + } let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); - - // we verify that is has_one relation - match (rel_def).rel_type { - RelationType::HasOne => (), - RelationType::HasMany => return Err(type_err("Relation is HasMany instead of HasOne")), + if rel_def.rel_type == RelationType::HasMany { + return Err(query_err("Relation is HasMany instead of HasOne")); } let keys: Vec = self @@ -84,9 +148,9 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -112,19 +176,22 @@ where Ok(result) } - async fn load_many(&self, stmt: Select, db: &C) -> Result>, DbErr> + async fn load_many(&self, stmt: S, db: &C) -> Result>, DbErr> where C: ConnectionTrait, R: EntityTrait, R::Model: Send + Sync, + S: EntityOrSelect, <::Model as ModelTrait>::Entity: Related, { - let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + // we verify that is HasMany relation - // we verify that is has_many relation - match (rel_def).rel_type { - RelationType::HasMany => (), - RelationType::HasOne => return Err(type_err("Relation is HasOne instead of HasMany")), + if <<::Model as ModelTrait>::Entity as Related>::via().is_some() { + return Err(query_err("Relation is ManyToMany instead of HasMany")); + } + let rel_def = <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type == RelationType::HasOne { + return Err(query_err("Relation is HasOne instead of HasMany")); } let keys: Vec = self @@ -132,9 +199,9 @@ where .map(|model: &M| extract_key(&rel_def.from_col, model)) .collect(); - let condition = prepare_condition::<::Model>(&rel_def.to_col, &keys); + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); - let stmt = as QueryFilter>::filter(stmt, condition); + let stmt = as QueryFilter>::filter(stmt.select(), condition); let data = stmt.all(db).await?; @@ -169,6 +236,103 @@ where Ok(result) } + + async fn load_many_to_many( + &self, + stmt: S, + via: V, + db: &C, + ) -> Result>, DbErr> + where + C: ConnectionTrait, + R: EntityTrait, + R::Model: Send + Sync, + S: EntityOrSelect, + V: EntityTrait, + V::Model: Send + Sync, + <::Model as ModelTrait>::Entity: Related, + { + if let Some(via_rel) = + <<::Model as ModelTrait>::Entity as Related>::via() + { + let rel_def = + <<::Model as ModelTrait>::Entity as Related>::to(); + if rel_def.rel_type != RelationType::HasOne { + return Err(query_err("Relation to is not HasOne")); + } + + if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) { + return Err(query_err(format!( + "The given via Entity is incorrect: expected: {:?}, given: {:?}", + via_rel.to_tbl, + via.table_ref() + ))); + } + + let pkeys: Vec = self + .iter() + .map(|model: &M| extract_key(&via_rel.from_col, model)) + .collect(); + + // Map of M::PK -> Vec + let mut keymap: HashMap> = Default::default(); + + let keys: Vec = { + let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys); + let stmt = V::find().filter(condition); + let data = stmt.all(db).await?; + data.into_iter().for_each(|model| { + let pk = format!("{:?}", extract_key(&via_rel.to_col, &model)); + let entry = keymap.entry(pk).or_default(); + + let fk = extract_key(&rel_def.from_col, &model); + entry.push(fk); + }); + + keymap.values().flatten().cloned().collect() + }; + + let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys); + + let stmt = as QueryFilter>::filter(stmt.select(), condition); + + let data = stmt.all(db).await?; + // Map of R::PK -> R::Model + let data: HashMap::Model> = data + .into_iter() + .map(|model| { + let key = format!("{:?}", extract_key(&rel_def.to_col, &model)); + (key, model) + }) + .collect(); + + let result: Vec> = pkeys + .into_iter() + .map(|pkey| { + let fkeys = keymap + .get(&format!("{pkey:?}")) + .cloned() + .unwrap_or_default(); + + let models: Vec<_> = fkeys + .into_iter() + .filter_map(|fkey| data.get(&format!("{fkey:?}")).cloned()) + .collect(); + + models + }) + .collect(); + + Ok(result) + } else { + return Err(query_err("Relation is not ManyToMany")); + } + } +} + +fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool { + // not ideal; but + format!("{left:?}") == format!("{right:?}") } fn extract_key(target_col: &Identity, model: &Model) -> ValueTuple @@ -222,54 +386,35 @@ where } } -fn prepare_condition(col: &Identity, keys: &[ValueTuple]) -> Condition -where - M: ModelTrait, -{ +fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition { match col { Identity::Unary(column_a) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:1")); - Condition::all().add(ColumnTrait::is_in( - &column_a, - keys.iter().cloned().flatten(), - )) - } - Identity::Binary(column_a, column_b) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:2")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:2")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) - } - Identity::Ternary(column_a, column_b, column_c) => { - let column_a: ::Column = - <::Column as FromStr>::from_str(&column_a.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *A:3")); - let column_b: ::Column = - <::Column as FromStr>::from_str(&column_b.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *B:3")); - let column_c: ::Column = - <::Column as FromStr>::from_str(&column_c.to_string()) - .unwrap_or_else(|_| panic!("Failed at mapping string to column *C:3")); - Condition::all().add( - Expr::tuple([ - SimpleExpr::Column(column_a.into_column_ref()), - SimpleExpr::Column(column_b.into_column_ref()), - SimpleExpr::Column(column_c.into_column_ref()), - ]) - .in_tuples(keys.iter().cloned()), - ) + let column_a = table_column(table, column_a); + Condition::all().add(Expr::col(column_a).is_in(keys.iter().cloned().flatten())) } + Identity::Binary(column_a, column_b) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + ]) + .in_tuples(keys.iter().cloned()), + ), + Identity::Ternary(column_a, column_b, column_c) => Condition::all().add( + Expr::tuple([ + SimpleExpr::Column(table_column(table, column_a)), + SimpleExpr::Column(table_column(table, column_b)), + SimpleExpr::Column(table_column(table, column_c)), + ]) + .in_tuples(keys.iter().cloned()), + ), + } +} + +fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef { + match tbl.to_owned() { + TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(), + TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(), + val => unimplemented!("Unsupported TableRef {val:?}"), } } diff --git a/tests/loader_tests.rs b/tests/loader_tests.rs index 0026fc85c..2db148a2c 100644 --- a/tests/loader_tests.rs +++ b/tests/loader_tests.rs @@ -13,65 +13,24 @@ async fn loader_load_one() -> Result<(), DbErr> { let ctx = TestContext::new("loader_test_load_one").await; create_tables(&ctx.db).await?; - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + let bakery_0 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - - let baker_2 = baker::ActiveModel { - name: Set("Baker 2".to_owned()), + let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery_0.id).await?; + let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery_0.id).await?; + let baker_3 = baker::ActiveModel { + name: Set("Baker 3".to_owned()), contact_details: Set(serde_json::json!({})), bakery_id: Set(None), ..Default::default() } .insert(&ctx.db) - .await - .expect("could not insert baker"); - - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); - - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); - - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery), None]); - - Ok(()) -} - -#[sea_orm_macros::test] -#[cfg(any( - feature = "sqlx-mysql", - feature = "sqlx-sqlite", - feature = "sqlx-postgres" -))] -async fn loader_load_one_complex() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_one_complex").await; - create_tables(&ctx.db).await?; - - let bakery = insert_bakery(&ctx.db, "SeaSide Bakery").await?; - - let baker_1 = insert_baker(&ctx.db, "Baker 1", bakery.id).await?; - let baker_2 = insert_baker(&ctx.db, "Baker 2", bakery.id).await?; - - let bakers = baker::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakers"); + .await?; - let bakeries = bakers - .load_one(bakery::Entity::find(), &ctx.db) - .await - .expect("Should load bakeries"); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity, &ctx.db).await?; - assert_eq!(bakers, [baker_1, baker_2]); - - assert_eq!(bakeries, [Some(bakery.clone()), Some(bakery.clone())]); + assert_eq!(bakers, [baker_1, baker_2, baker_3]); + assert_eq!(bakeries, [Some(bakery_0.clone()), Some(bakery_0), None]); Ok(()) } @@ -95,23 +54,26 @@ async fn loader_load_many() -> Result<(), DbErr> { let baker_3 = insert_baker(&ctx.db, "John", bakery_2.id).await?; let baker_4 = insert_baker(&ctx.db, "Baker 4", bakery_2.id).await?; - let bakeries = bakery::Entity::find() - .all(&ctx.db) - .await - .expect("Should load bakeries"); + let bakeries = bakery::Entity::find().all(&ctx.db).await?; + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; + + assert_eq!(bakeries, [bakery_1.clone(), bakery_2.clone()]); + assert_eq!( + bakers, + [ + [baker_1.clone(), baker_2.clone()], + [baker_3.clone(), baker_4.clone()] + ] + ); + + // load bakers again but with additional condition let bakers = bakeries .load_many( baker::Entity::find().filter(baker::Column::Name.like("Baker%")), &ctx.db, ) - .await - .expect("Should load bakers"); - - println!("A: {bakers:?}"); - println!("B: {bakeries:?}"); - - assert_eq!(bakeries, [bakery_1, bakery_2]); + .await?; assert_eq!( bakers, @@ -121,12 +83,22 @@ async fn loader_load_many() -> Result<(), DbErr> { ] ); - let bakers = bakeries - .load_many(baker::Entity::find(), &ctx.db) - .await - .expect("Should load bakers"); + // now, start from baker - assert_eq!(bakers, [[baker_1, baker_2], [baker_3, baker_4]]); + let bakers = baker::Entity::find().all(&ctx.db).await?; + let bakeries = bakers.load_one(bakery::Entity::find(), &ctx.db).await?; + + // note that two bakers share the same bakery + assert_eq!(bakers, [baker_1, baker_2, baker_3, baker_4]); + assert_eq!( + bakeries, + [ + Some(bakery_1.clone()), + Some(bakery_1), + Some(bakery_2.clone()), + Some(bakery_2) + ] + ); Ok(()) } @@ -137,8 +109,8 @@ async fn loader_load_many() -> Result<(), DbErr> { feature = "sqlx-sqlite", feature = "sqlx-postgres" ))] -async fn loader_load_many_many() -> Result<(), DbErr> { - let ctx = TestContext::new("loader_test_load_many_many").await; +async fn loader_load_many_multi() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_multi").await; create_tables(&ctx.db).await?; let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; @@ -148,17 +120,14 @@ async fn loader_load_many_many() -> Result<(), DbErr> { let baker_2 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; let baker_3 = insert_baker(&ctx.db, "Peter", bakery_2.id).await?; - let cake_1 = insert_cake(&ctx.db, "Cheesecake", bakery_1.id).await?; - let cake_2 = insert_cake(&ctx.db, "Chocolate", bakery_2.id).await?; - let cake_3 = insert_cake(&ctx.db, "Chiffon", bakery_2.id).await?; + let cake_1 = insert_cake(&ctx.db, "Cheesecake", Some(bakery_1.id)).await?; + let cake_2 = insert_cake(&ctx.db, "Chocolate", Some(bakery_2.id)).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", Some(bakery_2.id)).await?; + let _cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie let bakeries = bakery::Entity::find().all(&ctx.db).await?; - let bakers = bakeries.load_many(baker::Entity::find(), &ctx.db).await?; - let cakes = bakeries.load_many(cake::Entity::find(), &ctx.db).await?; - - println!("{bakers:?}"); - println!("{bakeries:?}"); - println!("{cakes:?}"); + let bakers = bakeries.load_many(baker::Entity, &ctx.db).await?; + let cakes = bakeries.load_many(cake::Entity, &ctx.db).await?; assert_eq!(bakeries, [bakery_1, bakery_2]); assert_eq!(bakers, [vec![baker_1, baker_2], vec![baker_3]]); @@ -167,6 +136,77 @@ async fn loader_load_many_many() -> Result<(), DbErr> { Ok(()) } +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn loader_load_many_to_many() -> Result<(), DbErr> { + let ctx = TestContext::new("loader_test_load_many_to_many").await; + create_tables(&ctx.db).await?; + + let bakery_1 = insert_bakery(&ctx.db, "SeaSide Bakery").await?; + + let baker_1 = insert_baker(&ctx.db, "Jane", bakery_1.id).await?; + let baker_2 = insert_baker(&ctx.db, "Peter", bakery_1.id).await?; + + let cake_1 = insert_cake(&ctx.db, "Cheesecake", None).await?; + let cake_2 = insert_cake(&ctx.db, "Coffee", None).await?; + let cake_3 = insert_cake(&ctx.db, "Chiffon", None).await?; + let cake_4 = insert_cake(&ctx.db, "Apple Pie", None).await?; // no one makes apple pie + + insert_cake_baker(&ctx.db, baker_1.id, cake_1.id).await?; + insert_cake_baker(&ctx.db, baker_1.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_2.id).await?; + insert_cake_baker(&ctx.db, baker_2.id, cake_3.id).await?; + + let bakers = baker::Entity::find().all(&ctx.db).await?; + let cakes = bakers + .load_many_to_many(cake::Entity, cakes_bakers::Entity, &ctx.db) + .await?; + + assert_eq!(bakers, [baker_1.clone(), baker_2.clone()]); + assert_eq!( + cakes, + [ + vec![cake_1.clone(), cake_2.clone()], + vec![cake_2.clone(), cake_3.clone()] + ] + ); + + // same, but apply restrictions on cakes + + let cakes = bakers + .load_many_to_many( + cake::Entity::find().filter(cake::Column::Name.like("Ch%")), + cakes_bakers::Entity, + &ctx.db, + ) + .await?; + assert_eq!(cakes, [vec![cake_1.clone()], vec![cake_3.clone()]]); + + // now, start again from cakes + + let cakes = cake::Entity::find().all(&ctx.db).await?; + let bakers = cakes + .load_many_to_many(baker::Entity, cakes_bakers::Entity, &ctx.db) + .await?; + + assert_eq!(cakes, [cake_1, cake_2, cake_3, cake_4]); + assert_eq!( + bakers, + [ + vec![baker_1.clone()], + vec![baker_1.clone(), baker_2.clone()], + vec![baker_2.clone()], + vec![] + ] + ); + + Ok(()) +} + pub async fn insert_bakery(db: &DbConn, name: &str) -> Result { bakery::ActiveModel { name: Set(name.to_owned()), @@ -188,14 +228,31 @@ pub async fn insert_baker(db: &DbConn, name: &str, bakery_id: i32) -> Result Result { +pub async fn insert_cake( + db: &DbConn, + name: &str, + bakery_id: Option, +) -> Result { cake::ActiveModel { name: Set(name.to_owned()), price: Set(rust_decimal::Decimal::ONE), gluten_free: Set(false), - bakery_id: Set(Some(bakery_id)), + bakery_id: Set(bakery_id), ..Default::default() } .insert(db) .await } + +pub async fn insert_cake_baker( + db: &DbConn, + baker_id: i32, + cake_id: i32, +) -> Result { + cakes_bakers::ActiveModel { + cake_id: Set(cake_id), + baker_id: Set(baker_id), + } + .insert(db) + .await +}