diff --git a/quaint/src/visitor.rs b/quaint/src/visitor.rs index 8424bc7fbb2b..c205b49dd279 100644 --- a/quaint/src/visitor.rs +++ b/quaint/src/visitor.rs @@ -1004,6 +1004,20 @@ pub trait Visitor<'a> { Ok(()) } + fn visit_min(&mut self, min: Minimum<'a>) -> Result { + self.write("MIN")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?; + + Ok(()) + } + + fn visit_max(&mut self, max: Maximum<'a>) -> Result { + self.write("MAX")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?; + + Ok(()) + } + fn visit_function(&mut self, fun: Function<'a>) -> Result { match fun.typ_ { FunctionType::RowNumber(fun_rownum) => { @@ -1046,12 +1060,10 @@ pub trait Visitor<'a> { self.surround_with("(", ")", |ref mut s| s.visit_expression(*upper.expression))?; } FunctionType::Minimum(min) => { - self.write("MIN")?; - self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?; + self.visit_min(min)?; } FunctionType::Maximum(max) => { - self.write("MAX")?; - self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?; + self.visit_max(max)?; } FunctionType::Coalesce(coalesce) => { self.write("COALESCE")?; diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index b587a7b5b0ec..648b3f0dc1ec 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -627,6 +627,34 @@ impl<'a> Visitor<'a> for Postgres<'a> { Ok(()) } + + fn visit_min(&mut self, min: Minimum<'a>) -> visitor::Result { + // If the inner column is a selected enum, then we cast the result of MIN(enum)::text instead of casting the inner enum column, which changes the behavior of MIN. + let should_cast = min.column.is_enum && min.column.is_selected; + + self.write("MIN")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(min.column.set_is_selected(false)))?; + + if should_cast { + self.write("::text")?; + } + + Ok(()) + } + + fn visit_max(&mut self, max: Maximum<'a>) -> visitor::Result { + // If the inner column is a selected enum, then we cast the result of MAX(enum)::text instead of casting the inner enum column, which changes the behavior of MAX. + let should_cast = max.column.is_enum && max.column.is_selected; + + self.write("MAX")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(max.column.set_is_selected(false)))?; + + if should_cast { + self.write("::text")?; + } + + Ok(()) + } } #[cfg(test)] @@ -1157,4 +1185,15 @@ mod tests { assert_eq!("SELECT \"User\".*, \"Toto\".* FROM \"User\" LEFT JOIN \"Post\" AS \"p\" ON \"p\".\"userId\" = \"User\".\"id\", \"Toto\"", sql); } + + #[test] + fn enum_cast_text_in_min_max_should_be_outside() { + let enum_col = Column::from("enum").set_is_enum(true).set_is_selected(true); + let q = Select::from_table("User") + .value(min(enum_col.clone())) + .value(max(enum_col)); + let (sql, _) = Postgres::build(q).unwrap(); + + assert_eq!("SELECT MIN(\"enum\")::text, MAX(\"enum\")::text FROM \"User\"", sql); + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs index 5abbbfe4bdf4..e372c4525f08 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs @@ -515,6 +515,54 @@ mod aggregation_group_by { Ok(()) } + fn schema_21789() -> String { + let schema = indoc! { + r#"model Test { + #id(id, Int, @id) + group Int + color Color + } + + enum Color { + blue + red + green + } + "# + }; + + schema.to_owned() + } + + // regression test for https://github.com/prisma/prisma/issues/21789 + #[connector_test(schema(schema_21789), only(Postgres, CockroachDB))] + async fn regression_21789(runner: Runner) -> TestResult<()> { + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 1, group: 1, color: "red" }) { id } }"# + ); + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 2, group: 2, color: "green" }) { id } }"# + ); + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 3, group: 1, color: "blue" }) { id } }"# + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ aggregateTest { _max { color } _min { color } } }"#), + @r###"{"data":{"aggregateTest":{"_max":{"color":"green"},"_min":{"color":"blue"}}}}"### + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ groupByTest(by: [group], orderBy: { group: asc }) { group _max { color } _min { color } } }"#), + @r###"{"data":{"groupByTest":[{"group":1,"_max":{"color":"red"},"_min":{"color":"blue"}},{"group":2,"_max":{"color":"green"},"_min":{"color":"green"}}]}}"### + ); + + Ok(()) + } + /// Error cases #[connector_test] diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 470628de1132..8b3bf9031019 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -18,7 +18,13 @@ pub(crate) async fn get_single_record( aggr_selections: &[RelAggregationSelection], ctx: &Context<'_>, ) -> crate::Result> { - let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, filter, ctx); + let query = read::get_records( + model, + selected_fields.as_columns(ctx).mark_all_selected(), + aggr_selections, + filter, + ctx, + ); let mut field_names: Vec<_> = selected_fields.db_names().collect(); let mut aggr_field_names: Vec<_> = aggr_selections.iter().map(|aggr_sel| aggr_sel.db_alias()).collect(); @@ -104,7 +110,13 @@ pub(crate) async fn get_many_records( let mut futures = FuturesUnordered::new(); for args in batches.into_iter() { - let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, args, ctx); + let query = read::get_records( + model, + selected_fields.as_columns(ctx).mark_all_selected(), + aggr_selections, + args, + ctx, + ); futures.push(conn.filter(query.into(), meta.as_slice(), ctx)); } @@ -122,7 +134,7 @@ pub(crate) async fn get_many_records( _ => { let query = read::get_records( model, - selected_fields.as_columns(ctx), + selected_fields.as_columns(ctx).mark_all_selected(), aggr_selections, query_arguments, ctx, diff --git a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs index 445bada9c45c..d3139082975b 100644 --- a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs +++ b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs @@ -7,6 +7,15 @@ pub struct ColumnIterator { inner: Box> + 'static>, } +impl ColumnIterator { + /// Sets all columns as selected. This is a hack that we use to help the Postgres SQL visitor cast enum columns to text to avoid some driver roundtrips otherwise needed to resolve enum types. + pub fn mark_all_selected(self) -> Self { + ColumnIterator { + inner: Box::new(self.inner.map(|c| c.set_is_selected(true))), + } + } +} + impl Iterator for ColumnIterator { type Item = Column<'static>; diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs index a5385f1dd56a..3f73bb51b2d5 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs @@ -124,9 +124,7 @@ where T: SelectDefinition, { let (select, additional_selection_set) = query.into_select(model, aggr_selections, ctx); - let select = columns - .map(|c| c.set_is_selected(true)) - .fold(select, |acc, col| acc.column(col)); + let select = columns.fold(select, |acc, col| acc.column(col)); let select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id); @@ -176,7 +174,11 @@ pub(crate) fn aggregate( .append_trace(&Span::current()) .add_trace_id(ctx.trace_id), |select, next_op| match next_op { - AggregationSelection::Field(field) => select.column(Column::from(field.db_name().to_owned())), + AggregationSelection::Field(field) => select.column( + Column::from(field.db_name().to_owned()) + .set_is_enum(field.type_identifier().is_enum()) + .set_is_selected(true), + ), AggregationSelection::Count { all, fields } => { let select = fields.iter().fold(select, |select, next_field| { @@ -199,11 +201,15 @@ pub(crate) fn aggregate( }), AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| { - select.value(min(Column::from(next_field.db_name().to_owned()))) + select.value(min(Column::from(next_field.db_name().to_owned()) + .set_is_enum(next_field.type_identifier().is_enum()) + .set_is_selected(true))) }), AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| { - select.value(max(Column::from(next_field.db_name().to_owned()))) + select.value(max(Column::from(next_field.db_name().to_owned()) + .set_is_enum(next_field.type_identifier().is_enum()) + .set_is_selected(true))) }), }, ) @@ -243,11 +249,11 @@ pub(crate) fn group_by_aggregate( }), AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| { - select.value(min(next_field.as_column(ctx))) + select.value(min(next_field.as_column(ctx).set_is_selected(true))) }), AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| { - select.value(max(next_field.as_column(ctx))) + select.value(max(next_field.as_column(ctx).set_is_selected(true))) }), });