77from django .db import IntegrityError , NotSupportedError
88from django .db .models import Count
99from django .db .models .aggregates import Aggregate , Variance
10- from django .db .models .expressions import Case , Col , Ref , Value , When
10+ from django .db .models .expressions import Case , Col , OrderBy , Ref , Value , When
1111from django .db .models .functions .comparison import Coalesce
1212from django .db .models .functions .math import Power
1313from django .db .models .lookups import IsNull
@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs):
3232 # A list of OrderBy objects for this query.
3333 self .order_by_objs = None
3434
35+ def _unfold_column (self , col ):
36+ """
37+ Flatten a field by returning its target or by replacing dots with
38+ GROUP_SEPARATOR for foreign fields.
39+ """
40+ if self .collection_name == col .alias :
41+ return col .target .column
42+ # If this is a foreign field, replace the normal dot (.) with
43+ # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } "
45+
46+ def _fold_columns (self , unfold_columns ):
47+ """
48+ Convert flat columns into a nested dictionary, grouping fields by
49+ table name.
50+ """
51+ result = defaultdict (dict )
52+ for key in unfold_columns :
53+ value = f"$_id.{ key } "
54+ if self .GROUP_SEPARATOR in key :
55+ table , field = key .split (self .GROUP_SEPARATOR )
56+ result [table ][field ] = value
57+ else :
58+ result [key ] = value
59+ # Convert defaultdict to dict so it doesn't appear as
60+ # "defaultdict(<CLASS 'dict'>, ..." in query logging.
61+ return dict (result )
62+
3563 def _get_group_alias_column (self , expr , annotation_group_idx ):
3664 """Generate a dummy field for use in the ids fields in $group."""
3765 replacement = None
@@ -42,11 +70,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4270 alias = f"__annotation_group{ next (annotation_group_idx )} "
4371 col = self ._get_column_from_expression (expr , alias )
4472 replacement = col
45- if self .collection_name == col .alias :
46- return col .target .column , replacement
47- # If this is a foreign field, replace the normal dot (.) with
48- # GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49- return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
73+ return self ._unfold_column (col ), replacement
5074
5175 def _get_column_from_expression (self , expr , alias ):
5276 """
@@ -186,17 +210,8 @@ def _build_aggregation_pipeline(self, ids, group):
186210 else :
187211 group ["_id" ] = ids
188212 pipeline .append ({"$group" : group })
189- projected_fields = defaultdict (dict )
190- for key in ids :
191- value = f"$_id.{ key } "
192- if self .GROUP_SEPARATOR in key :
193- table , field = key .split (self .GROUP_SEPARATOR )
194- projected_fields [table ][field ] = value
195- else :
196- projected_fields [key ] = value
197- # Convert defaultdict to dict so it doesn't appear as
198- # "defaultdict(<CLASS 'dict'>, ..." in query logging.
199- pipeline .append ({"$addFields" : dict (projected_fields )})
213+ projected_fields = self ._fold_columns (ids )
214+ pipeline .append ({"$addFields" : projected_fields })
200215 if "_id" not in projected_fields :
201216 pipeline .append ({"$unset" : "_id" })
202217 return pipeline
@@ -349,23 +364,30 @@ def build_query(self, columns=None):
349364 """Check if the query is supported and prepare a MongoQuery."""
350365 self .check_query ()
351366 query = self .query_class (self )
352- query .lookup_pipeline = self .get_lookup_pipeline ()
353367 ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
354- query .project_fields = self .get_project_fields (columns , ordering_fields )
355368 query .ordering = sort_ordering
356- # If columns is None, then get_project_fields() won't add
357- # ordering_fields to $project. Use $addFields (extra_fields) instead.
358- if columns is None :
359- extra_fields += ordering_fields
369+ if self .query .combinator :
370+ if not getattr (self .connection .features , f"supports_select_{ self .query .combinator } " ):
371+ raise NotSupportedError (
372+ f"{ self .query .combinator } is not supported on this database backend."
373+ )
374+ query .combinator_pipeline = self .get_combinator_queries ()
375+ else :
376+ query .project_fields = self .get_project_fields (columns , ordering_fields )
377+ # If columns is None, then get_project_fields() won't add
378+ # ordering_fields to $project. Use $addFields (extra_fields) instead.
379+ if columns is None :
380+ extra_fields += ordering_fields
381+ query .lookup_pipeline = self .get_lookup_pipeline ()
382+ where = self .get_where ()
383+ try :
384+ expr = where .as_mql (self , self .connection ) if where else {}
385+ except FullResultSet :
386+ query .mongo_query = {}
387+ else :
388+ query .mongo_query = {"$expr" : expr }
360389 if extra_fields :
361390 query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
362- where = self .get_where ()
363- try :
364- expr = where .as_mql (self , self .connection ) if where else {}
365- except FullResultSet :
366- query .mongo_query = {}
367- else :
368- query .mongo_query = {"$expr" : expr }
369391 return query
370392
371393 def get_columns (self ):
@@ -391,6 +413,9 @@ def project_field(column):
391413 if hasattr (column , "target" ):
392414 # column is a Col.
393415 target = column .target .column
416+ # Handle Order By columns as refs columns.
417+ elif isinstance (column , OrderBy ) and isinstance (column .expression , Ref ):
418+ target = column .expression .refs
394419 else :
395420 # column is a Transform in values()/values_list() that needs a
396421 # name for $proj.
@@ -412,6 +437,75 @@ def collection_name(self):
412437 def collection (self ):
413438 return self .connection .get_collection (self .collection_name )
414439
440+ def get_combinator_queries (self ):
441+ parts = []
442+ compilers = [
443+ query .get_compiler (self .using , self .connection , self .elide_empty )
444+ for query in self .query .combined_queries
445+ ]
446+ main_query_columns = self .get_columns ()
447+ main_query_fields , _ = zip (* main_query_columns , strict = True )
448+ for compiler_ in compilers :
449+ try :
450+ # If the columns list is limited, then all combined queries
451+ # must have the same columns list. Set the selects defined on
452+ # the query on all combined queries, if not already set.
453+ if not compiler_ .query .values_select and self .query .values_select :
454+ compiler_ .query = compiler_ .query .clone ()
455+ compiler_ .query .set_values (
456+ (
457+ * self .query .extra_select ,
458+ * self .query .values_select ,
459+ * self .query .annotation_select ,
460+ )
461+ )
462+ compiler_ .pre_sql_setup ()
463+ columns = compiler_ .get_columns ()
464+ parts .append ((compiler_ .build_query (columns ), compiler_ , columns ))
465+ except EmptyResultSet :
466+ # Omit the empty queryset with UNION.
467+ if self .query .combinator == "union" :
468+ continue
469+ raise
470+ # Raise EmptyResultSet if all the combinator queries are empty.
471+ if not parts :
472+ raise EmptyResultSet
473+ # Make the combinator's stages.
474+ combinator_pipeline = None
475+ for part , compiler_ , columns in parts :
476+ inner_pipeline = part .get_pipeline ()
477+ # Standardize result fields.
478+ fields = {}
479+ # When a .count() is called, the main_query_field has length 1
480+ # otherwise it has the same length as columns.
481+ for alias , (ref , expr ) in zip (main_query_fields , columns , strict = False ):
482+ if isinstance (expr , Col ) and expr .alias != compiler_ .collection_name :
483+ fields [expr .alias ] = 1
484+ else :
485+ fields [alias ] = f"${ ref } " if alias != ref else 1
486+ inner_pipeline .append ({"$project" : fields })
487+ # Combine query with the current combinator pipeline.
488+ if combinator_pipeline :
489+ combinator_pipeline .append (
490+ {"$unionWith" : {"coll" : compiler_ .collection_name , "pipeline" : inner_pipeline }}
491+ )
492+ else :
493+ combinator_pipeline = inner_pipeline
494+ if not self .query .combinator_all :
495+ ids = {}
496+ for alias , expr in main_query_columns :
497+ # Unfold foreign fields.
498+ if isinstance (expr , Col ) and expr .alias != self .collection_name :
499+ ids [self ._unfold_column (expr )] = expr .as_mql (self , self .connection )
500+ else :
501+ ids [alias ] = f"${ alias } "
502+ combinator_pipeline .append ({"$group" : {"_id" : ids }})
503+ projected_fields = self ._fold_columns (ids )
504+ combinator_pipeline .append ({"$addFields" : projected_fields })
505+ if "_id" not in projected_fields :
506+ combinator_pipeline .append ({"$unset" : "_id" })
507+ return combinator_pipeline
508+
415509 def get_lookup_pipeline (self ):
416510 result = []
417511 for alias in tuple (self .query .alias_map ):
0 commit comments