Skip to content

Commit 4761579

Browse files
WaVEVtimgraham
authored andcommitted
Refactor Aggregation and Count implementations
Clean up logic for better alignment with changes in Django 6.0.
1 parent 590029b commit 4761579

File tree

4 files changed

+27
-30
lines changed

4 files changed

+27
-30
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
22
from django.db.models.expressions import Case, Value, When
33
from django.db.models.lookups import IsNull
4+
from django.db.models.sql.where import WhereNode
45

5-
from .query_utils import process_lhs
6+
from django_mongodb_backend.expressions import Remove
67

78
# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
89
MONGO_AGGREGATIONS = {Count: "sum"}
910

1011

1112
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
13+
agg_expression, *_ = self.get_source_expressions()
1214
if self.filter:
13-
node = self.copy()
14-
node.filter = None
15-
source_expressions = node.get_source_expressions()
16-
condition = When(self.filter, then=source_expressions[0])
17-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
18-
else:
19-
node = self
20-
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
15+
agg_expression = Case(
16+
When(self.filter, then=agg_expression),
17+
# Skip rows that don't meet the criteria.
18+
default=Remove(),
19+
)
20+
lhs_mql = agg_expression.as_mql(compiler, connection, as_expr=True)
2121
if resolve_inner_expression:
2222
return lhs_mql
2323
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -30,31 +30,23 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3030
value. This is used to count different elements, so the inner values are
3131
returned to be pushed into a set.
3232
"""
33+
agg_expression, *_ = self.get_source_expressions()
3334
if not self.distinct or resolve_inner_expression:
35+
conditions = [IsNull(agg_expression, False)]
3436
if self.filter:
35-
node = self.copy()
36-
node.filter = None
37-
source_expressions = node.get_source_expressions()
38-
condition = When(
39-
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
40-
)
41-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
42-
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
43-
else:
44-
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
45-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
46-
inner_expression = {
47-
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
48-
}
37+
conditions.append(self.filter)
38+
inner_expression = Case(
39+
When(WhereNode(conditions), then=agg_expression if self.distinct else Value(1)),
40+
# Skip rows that don't meet the criteria.
41+
default=Remove(),
42+
)
43+
inner_expression = inner_expression.as_mql(compiler, connection, as_expr=True)
4944
if resolve_inner_expression:
5045
return inner_expression
5146
return {"$sum": inner_expression}
5247
# If distinct=True or resolve_inner_expression=False, sum the size of the
5348
# set.
54-
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
55-
# None shouldn't be counted, so subtract 1 if it's present.
56-
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
57-
return {"$add": [{"$size": lhs_mql}, exits_null]}
49+
return {"$size": agg_expression.as_mql(compiler, connection, as_expr=True)}
5850

5951

6052
def stddev_variance(self, compiler, connection):

django_mongodb_backend/expressions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .expressions import Remove
12
from .search import (
23
CombinedSearchExpression,
34
CompoundExpression,
@@ -21,6 +22,7 @@
2122
__all__ = [
2223
"CombinedSearchExpression",
2324
"CompoundExpression",
25+
"Remove",
2426
"SearchAutocomplete",
2527
"SearchEquals",
2628
"SearchExists",
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.db.models.expressions import Func
2+
3+
4+
class Remove(Func):
5+
def as_mql(self, compiler, connection, as_expr=False):
6+
return "$$REMOVE"

django_mongodb_backend/query_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from django.core.exceptions import FullResultSet
22
from django.db.models import F
3-
from django.db.models.aggregates import Aggregate
43
from django.db.models.expressions import CombinedExpression, Func, Value
54
from django.db.models.sql.query import Query
65

@@ -20,8 +19,6 @@ def process_lhs(node, compiler, connection, as_expr=False):
2019
result.append(expr.as_mql(compiler, connection, as_expr=as_expr))
2120
except FullResultSet:
2221
result.append(Value(True).as_mql(compiler, connection, as_expr=as_expr))
23-
if isinstance(node, Aggregate):
24-
return result[0]
2522
return result
2623
# node is a Transform with just one source expression, aliased as "lhs".
2724
if is_direct_value(node.lhs):

0 commit comments

Comments
 (0)