11from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
22from django .db .models .expressions import Case , Value , When
33from django .db .models .lookups import IsNull
4+ from django .db .models .sql .where import WhereNode
45
5- from . query_utils import process_lhs
6+ from django_mongodb_backend . expressions import Remove
67
78# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
89MONGO_AGGREGATIONS = {Count : "sum" }
910
1011
1112def aggregate (self , compiler , connection , operator = None , resolve_inner_expression = False ):
13+ agg_expression , * _ = self .get_source_expressions ()
1214 if self .filter :
13- node = self .copy ()
14- node .filter = None
15- source_expressions = node .get_source_expressions ()
16- condition = When (self .filter , then = source_expressions [0 ])
17- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
18- else :
19- node = self
20- lhs_mql = process_lhs (node , compiler , connection , as_expr = True )
15+ agg_expression = Case (
16+ When (self .filter , then = agg_expression ),
17+ # Skip rows that don't meet the criteria.
18+ default = Remove (),
19+ )
20+ lhs_mql = agg_expression .as_mql (compiler , connection , as_expr = True )
2121 if resolve_inner_expression :
2222 return lhs_mql
2323 operator = operator or MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
@@ -30,31 +30,23 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3030 value. This is used to count different elements, so the inner values are
3131 returned to be pushed into a set.
3232 """
33+ agg_expression , * _ = self .get_source_expressions ()
3334 if not self .distinct or resolve_inner_expression :
35+ conditions = [IsNull (agg_expression , False )]
3436 if self .filter :
35- node = self .copy ()
36- node .filter = None
37- source_expressions = node .get_source_expressions ()
38- condition = When (
39- self .filter , then = Case (When (IsNull (source_expressions [0 ], False ), then = Value (1 )))
40- )
41- node .set_source_expressions ([Case (condition ), * source_expressions [1 :]])
42- inner_expression = process_lhs (node , compiler , connection , as_expr = True )
43- else :
44- lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
45- null_cond = {"$in" : [{"$type" : lhs_mql }, ["missing" , "null" ]]}
46- inner_expression = {
47- "$cond" : {"if" : null_cond , "then" : None , "else" : lhs_mql if self .distinct else 1 }
48- }
37+ conditions .append (self .filter )
38+ inner_expression = Case (
39+ When (WhereNode (conditions ), then = agg_expression if self .distinct else Value (1 )),
40+ # Skip rows that don't meet the criteria.
41+ default = Remove (),
42+ )
43+ inner_expression = inner_expression .as_mql (compiler , connection , as_expr = True )
4944 if resolve_inner_expression :
5045 return inner_expression
5146 return {"$sum" : inner_expression }
5247 # If distinct=True or resolve_inner_expression=False, sum the size of the
5348 # set.
54- lhs_mql = process_lhs (self , compiler , connection , as_expr = True )
55- # None shouldn't be counted, so subtract 1 if it's present.
56- exits_null = {"$cond" : {"if" : {"$in" : [{"$literal" : None }, lhs_mql ]}, "then" : - 1 , "else" : 0 }}
57- return {"$add" : [{"$size" : lhs_mql }, exits_null ]}
49+ return {"$size" : agg_expression .as_mql (compiler , connection , as_expr = True )}
5850
5951
6052def stddev_variance (self , compiler , connection ):
0 commit comments