Skip to content

Commit 2e6c87c

Browse files
committed
Aggregation rework
1 parent 17592fc commit 2e6c87c

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
from django.core.exceptions import EmptyResultSet, FullResultSet
12
from django.db import NotSupportedError
23
from django.db.models.aggregates import (
34
Aggregate,
4-
AggregateFilter,
55
Count,
66
StdDev,
77
StringAgg,
88
Variance,
99
)
10-
from django.db.models.expressions import Case, Col, Value, When
10+
from django.db.models.expressions import Case, Value, When
1111
from django.db.models.lookups import IsNull
1212

1313
from .query_utils import process_lhs
@@ -20,26 +20,26 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
2020
# TODO: isinstance(self.filter, Col) works around failure of
2121
# aggregation.tests.AggregateTestCase.test_distinct_on_aggregate. Is this
2222
# correct?
23-
if self.filter is not None and not isinstance(self.filter, Col):
23+
if self.filter is not None:
2424
# Generate a CASE statement for this aggregate.
25-
node = self.copy()
26-
node.filter = None
27-
source_expressions = node.get_source_expressions()
28-
condition = When(self.filter, then=source_expressions[0])
29-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
25+
try:
26+
lhs_mql = self.filter.as_mql(compiler, connection, as_expr=True)
27+
except NotSupportedError:
28+
source_expressions = self.get_source_expressions()
29+
condition = Case(When(self.filter.condition, then=source_expressions[0]))
30+
lhs_mql = condition.as_mql(compiler, connection)
31+
except FullResultSet:
32+
lhs_mql = source_expressions[0].as_mql(compiler, connection, as_expr=True)
33+
except EmptyResultSet:
34+
lhs_mql = Value(None).as_mql(compiler, connection, as_expr=True)
3035
else:
31-
node = self
32-
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
36+
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
3337
if resolve_inner_expression:
3438
return lhs_mql
3539
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
3640
return {f"${operator}": lhs_mql}
3741

3842

39-
def aggregate_filter(self, compiler, connection):
40-
return self.condition.as_mql(compiler, connection, as_expr=True)
41-
42-
4343
def count(self, compiler, connection, resolve_inner_expression=False):
4444
"""
4545
When resolve_inner_expression=True, return the MQL that resolves as a
@@ -48,14 +48,19 @@ def count(self, compiler, connection, resolve_inner_expression=False):
4848
"""
4949
if not self.distinct or resolve_inner_expression:
5050
if self.filter:
51-
node = self.copy()
52-
node.filter = None
53-
source_expressions = node.get_source_expressions()
54-
condition = When(
55-
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
56-
)
57-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
58-
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
51+
try:
52+
inner_expression = self.filter.as_mql(compiler, connection, as_expr=True)
53+
except NotSupportedError:
54+
source_expressions = self.get_source_expressions()
55+
condition = When(
56+
self.filter.condition,
57+
then=Case(When(IsNull(source_expressions[0], False), then=Value(1))),
58+
)
59+
inner_expression = Case(condition).as_mql(compiler, connection, as_expr=True)
60+
except FullResultSet:
61+
inner_expression = {"$sum": 1}
62+
except EmptyResultSet:
63+
inner_expression = {"$sum": 0}
5964
else:
6065
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
6166
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
@@ -87,7 +92,7 @@ def string_agg(self, compiler, connection): # noqa: ARG001
8792

8893
def register_aggregates():
8994
Aggregate.as_mql_expr = aggregate
90-
AggregateFilter.as_mql_expr = aggregate_filter
95+
# AggregateFilter.as_mql_expr = aggregate_filter
9196
Count.as_mql_expr = count
9297
StdDev.as_mql_expr = stddev_variance
9398
StringAgg.as_mql_expr = string_agg

0 commit comments

Comments
 (0)