2222
2323import  org .bson .Document ;
2424import  org .jspecify .annotations .NullUnmarked ;
25- 
2625import  org .springframework .core .ResolvableType ;
2726import  org .springframework .core .annotation .MergedAnnotation ;
2827import  org .springframework .data .domain .SliceImpl ;
2928import  org .springframework .data .domain .Sort .Order ;
3029import  org .springframework .data .mongodb .core .MongoOperations ;
3130import  org .springframework .data .mongodb .core .aggregation .Aggregation ;
31+ import  org .springframework .data .mongodb .core .aggregation .AggregationOperation ;
3232import  org .springframework .data .mongodb .core .aggregation .AggregationOptions ;
3333import  org .springframework .data .mongodb .core .aggregation .AggregationPipeline ;
3434import  org .springframework .data .mongodb .core .aggregation .AggregationResults ;
@@ -80,12 +80,7 @@ CodeBlock build() {
8080
8181			builder .add ("\n " );
8282
83- 			Class <?> outputType  = queryMethod .getReturnedObjectType ();
84- 			if  (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
85- 				outputType  = Document .class ;
86- 			} else  if  (ClassUtils .isAssignable (AggregationResults .class , outputType )) {
87- 				outputType  = queryMethod .getReturnType ().getComponentType ().getType ();
88- 			}
83+ 			Class <?> outputType  = getOutputType (queryMethod );
8984
9085			if  (ReflectionUtils .isVoid (queryMethod .getReturnedObjectType ())) {
9186				builder .addStatement ("$L.aggregate($L, $T.class)" , mongoOpsRef , aggregationVariableName , outputType );
@@ -146,7 +141,6 @@ CodeBlock build() {
146141						builder .addStatement ("return $L.aggregateStream($L, $T.class)" , mongoOpsRef , aggregationVariableName ,
147142								outputType );
148143					} else  {
149- 
150144						builder .addStatement ("return $L.aggregate($L, $T.class).getMappedResults()" , mongoOpsRef ,
151145								aggregationVariableName , outputType );
152146					}
@@ -155,6 +149,17 @@ CodeBlock build() {
155149
156150			return  builder .build ();
157151		}
152+ 
153+ 	}
154+ 
155+ 	private  static  Class <?> getOutputType (MongoQueryMethod  queryMethod ) {
156+ 		Class <?> outputType  = queryMethod .getReturnedObjectType ();
157+ 		if  (MongoSimpleTypes .HOLDER .isSimpleType (outputType )) {
158+ 			outputType  = Document .class ;
159+ 		} else  if  (ClassUtils .isAssignable (AggregationResults .class , outputType ) && queryMethod .getReturnType ().getComponentType () != null ) {
160+ 			outputType  = queryMethod .getReturnType ().getComponentType ().getType ();
161+ 		}
162+ 		return  outputType ;
158163	}
159164
160165	@ NullUnmarked 
@@ -173,13 +178,7 @@ static class AggregationCodeBlockBuilder {
173178
174179			this .context  = context ;
175180			this .queryMethod  = queryMethod ;
176- 			String  parameterNames  = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
177- 
178- 			if  (StringUtils .hasText (parameterNames )) {
179- 				this .parameterNames  = ", "  + parameterNames ;
180- 			} else  {
181- 				this .parameterNames  = "" ;
182- 			}
181+ 			this .parameterNames  = StringUtils .collectionToDelimitedString (context .getAllParameterNames (), ", " );
183182		}
184183
185184		AggregationCodeBlockBuilder  stages (AggregationInteraction  aggregation ) {
@@ -231,7 +230,8 @@ private CodeBlock pipeline(String pipelineVariableName) {
231230			builder .add (aggregationStages (context .localVariable ("stages" ), source .stages ()));
232231
233232			if  (StringUtils .hasText (sortParameter )) {
234- 				builder .add (sortingStage (sortParameter ));
233+ 				Class <?> outputType  = getOutputType (queryMethod );
234+ 				builder .add (sortingStage (sortParameter , outputType ));
235235			}
236236
237237			if  (StringUtils .hasText (limitParameter )) {
@@ -244,6 +244,7 @@ private CodeBlock pipeline(String pipelineVariableName) {
244244
245245			builder .addStatement ("$T $L = createPipeline($L)" , AggregationPipeline .class , pipelineVariableName ,
246246					context .localVariable ("stages" ));
247+ 
247248			return  builder .build ();
248249		}
249250
@@ -312,7 +313,7 @@ private CodeBlock aggregationStages(String stageListVariableName, Collection<Str
312313			return  builder .build ();
313314		}
314315
315- 		private  CodeBlock  sortingStage (String  sortProvider ) {
316+ 		private  CodeBlock  sortingStage (String  sortProvider ,  Class <?>  outputType ) {
316317
317318			Builder  builder  = CodeBlock .builder ();
318319
@@ -322,8 +323,17 @@ private CodeBlock sortingStage(String sortProvider) {
322323			builder .addStatement ("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);" ,
323324					context .localVariable ("sortDocument" ), context .localVariable ("order" ));
324325			builder .endControlFlow ();
325- 			builder .addStatement ("stages.add(new $T($S, $L))" , Document .class , "$sort" ,
326- 					context .localVariable ("sortDocument" ));
326+ 
327+ 			if  (outputType  == Document .class  || MongoSimpleTypes .HOLDER .isSimpleType (outputType )
328+ 					|| ClassUtils .isAssignable (context .getRepositoryInformation ().getDomainType (), outputType )) {
329+ 				builder .addStatement ("$L.add(new $T($S, $L))" , context .localVariable ("stages" ), Document .class , "$sort" ,
330+ 						context .localVariable ("sortDocument" ));
331+ 			} else  {
332+ 				builder .addStatement ("$L.add(($T) _ctx -> new $T($S, _ctx.getMappedObject($L, $T.class)))" ,
333+ 						context .localVariable ("stages" ), AggregationOperation .class , Document .class , "$sort" ,
334+ 						context .localVariable ("sortDocument" ), outputType );
335+ 			}
336+ 
327337			builder .endControlFlow ();
328338
329339			return  builder .build ();
@@ -333,7 +343,7 @@ private CodeBlock pagingStage(String pageableProvider, boolean slice) {
333343
334344			Builder  builder  = CodeBlock .builder ();
335345
336- 			builder .add (sortingStage (pageableProvider  + ".getSort()" ));
346+ 			builder .add (sortingStage (pageableProvider  + ".getSort()" ,  getOutputType ( queryMethod ) ));
337347
338348			builder .beginControlFlow ("if ($L.isPaged())" , pageableProvider );
339349			builder .beginControlFlow ("if ($L.getOffset() > 0)" , pageableProvider );
0 commit comments