1+ from django .core .exceptions import EmptyResultSet , FullResultSet
12from django .db import NotSupportedError
23from django .db .models .aggregates import (
34 Aggregate ,
4- AggregateFilter ,
55 Count ,
66 StdDev ,
77 StringAgg ,
88 Variance ,
99)
10- from django .db .models .expressions import Case , Col , Value , When
10+ from django .db .models .expressions import Case , Value , When
1111from django .db .models .lookups import IsNull
1212
1313from .query_utils import process_lhs
@@ -20,26 +20,26 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2020 # TODO: isinstance(self.filter, Col) works around failure of
2121 # aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
2222 # correct?
23- if self .filter is not None and not isinstance ( self . filter , Col ) :
23+ if self .filter is not None :
2424 # Generate a CASE statement for this aggregate.
25- node = self .copy ()
26- node .filter = None
27- source_expressions = node .get_source_expressions ()
28- condition = When (self .filter , then = source_expressions [0 ])
29- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
25+ try :
26+ lhs_mql = self .filter .as_mql (compiler , connection , as_expr = True )
27+ except NotSupportedError :
28+ source_expressions = self .get_source_expressions ()
29+ condition = Case (When (self .filter .condition , then = source_expressions [0 ]))
30+ lhs_mql = condition .as_mql (compiler , connection )
31+ except FullResultSet :
32+ lhs_mql = source_expressions [0 ].as_mql (compiler , connection , as_expr = True )
33+ except EmptyResultSet :
34+ lhs_mql = Value (None ).as_mql (compiler , connection , as_expr = True )
3035 else :
31- node = self
32- lhs_mql = process_lhs (node , compiler , connection , as_expr = True )
36+ lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
3337 if resolve_inner_expression :
3438 return lhs_mql
3539 operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
3640 return {f"${ operator } " : lhs_mql }
3741
3842
39- def aggregate_filter (self , compiler , connection ):
40- return self .condition .as_mql (compiler , connection , as_expr = True )
41-
42-
4343def count (self , compiler , connection , resolve_inner_expression = False ):
4444 """
4545 When resolve_inner_expression=True, return the MQL that resolves as a
@@ -48,14 +48,19 @@ def count(self, compiler, connection, resolve_inner_expression=False):
4848 """
4949 if not self .distinct or resolve_inner_expression :
5050 if self .filter :
51- node = self .copy ()
52- node .filter = None
53- source_expressions = node .get_source_expressions ()
54- condition = When (
55- self .filter , then = Case (When (IsNull (source_expressions [0 ], False ), then = Value (1 )))
56- )
57- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
58- inner_expression = process_lhs (node , compiler , connection , as_expr = True )
51+ try :
52+ inner_expression = self .filter .as_mql (compiler , connection , as_expr = True )
53+ except NotSupportedError :
54+ source_expressions = self .get_source_expressions ()
55+ condition = When (
56+ self .filter .condition ,
57+ then = Case (When (IsNull (source_expressions [0 ], False ), then = Value (1 ))),
58+ )
59+ inner_expression = Case (condition ).as_mql (compiler , connection , as_expr = True )
60+ except FullResultSet :
61+ inner_expression = {"$sum" : 1 }
62+ except EmptyResultSet :
63+ inner_expression = {"$sum" : 0 }
5964 else :
6065 lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
6166 null_cond = {"$in" : [{"$type" : lhs_mql }, ["missing" , "null" ]]}
@@ -87,7 +92,7 @@ def string_agg(self, compiler, connection): # noqa: ARG001
8792
8893def register_aggregates ():
8994 Aggregate .as_mql_expr = aggregate
90- AggregateFilter .as_mql_expr = aggregate_filter
95+ # AggregateFilter.as_mql_expr = aggregate_filter
9196 Count .as_mql_expr = count
9297 StdDev .as_mql_expr = stddev_variance
9398 StringAgg .as_mql_expr = string_agg
0 commit comments