1616package org .springframework .data .jdbc .core ;
1717
1818import java .util .Optional ;
19+ import java .util .function .Function ;
20+ import java .util .stream .Collectors ;
21+ import java .util .stream .StreamSupport ;
1922
2023import org .springframework .context .ApplicationEventPublisher ;
2124import org .springframework .data .jdbc .core .convert .DataAccessStrategy ;
2225import org .springframework .data .mapping .IdentifierAccessor ;
26+ import org .springframework .data .mapping .callback .EntityCallback ;
27+ import org .springframework .data .mapping .callback .EntityCallbacks ;
2328import org .springframework .data .relational .core .conversion .AggregateChange ;
2429import org .springframework .data .relational .core .conversion .AggregateChange .Kind ;
2530import org .springframework .data .relational .core .conversion .Interpreter ;
3035import org .springframework .data .relational .core .conversion .RelationalEntityWriter ;
3136import org .springframework .data .relational .core .mapping .RelationalMappingContext ;
3237import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
33- import org .springframework .data .relational .core .mapping .event .AfterDeleteEvent ;
34- import org .springframework .data .relational .core .mapping .event .AfterLoadEvent ;
35- import org .springframework .data .relational .core .mapping .event .AfterSaveEvent ;
36- import org .springframework .data .relational .core .mapping .event .BeforeDeleteEvent ;
37- import org .springframework .data .relational .core .mapping .event .BeforeSaveEvent ;
38- import org .springframework .data .relational .core .mapping .event .Identifier ;
38+ import org .springframework .data .relational .core .mapping .event .*;
3939import org .springframework .data .relational .core .mapping .event .Identifier .Specified ;
4040import org .springframework .lang .Nullable ;
4141import org .springframework .util .Assert ;
@@ -62,6 +62,8 @@ public class JdbcAggregateTemplate implements JdbcAggregateOperations {
6262
6363 private final DataAccessStrategy accessStrategy ;
6464
65+ private EntityCallbacks entityCallbacks = NoopEntityCallback .INSTANCE ;
66+
6567 /**
6668 * Creates a new {@link JdbcAggregateTemplate} given {@link ApplicationEventPublisher},
6769 * {@link RelationalMappingContext} and {@link DataAccessStrategy}.
@@ -90,6 +92,13 @@ public JdbcAggregateTemplate(ApplicationEventPublisher publisher, RelationalMapp
9092 this .interpreter = new DefaultJdbcInterpreter (context , accessStrategy );
9193 }
9294
95+ public void setEntityCallbacks (EntityCallbacks entityCallbacks ) {
96+
97+ Assert .notNull (entityCallbacks , "Callbacks must not be null." );
98+
99+ this .entityCallbacks = entityCallbacks ;
100+ }
101+
93102 /*
94103 * (non-Javadoc)
95104 * @see org.springframework.data.jdbc.core.JdbcAggregateOperations#save(java.lang.Object)
@@ -100,11 +109,11 @@ public <T> T save(T instance) {
100109 Assert .notNull (instance , "Aggregate instance must not be null!" );
101110
102111 RelationalPersistentEntity <?> persistentEntity = context .getRequiredPersistentEntity (instance .getClass ());
103- IdentifierAccessor identifierAccessor = persistentEntity .getIdentifierAccessor (instance );
104112
105- AggregateChange <T > change = createChange (instance );
113+ Function <T , AggregateChange <T >> changeCreator = persistentEntity .isNew (instance ) ? this ::createInsertChange
114+ : this ::createUpdateChange ;
106115
107- return store (instance , identifierAccessor , change , persistentEntity );
116+ return store (instance , changeCreator , persistentEntity );
108117 }
109118
110119 /**
@@ -120,11 +129,8 @@ public <T> T insert(T instance) {
120129 Assert .notNull (instance , "Aggregate instance must not be null!" );
121130
122131 RelationalPersistentEntity <?> persistentEntity = context .getRequiredPersistentEntity (instance .getClass ());
123- IdentifierAccessor identifierAccessor = persistentEntity .getIdentifierAccessor (instance );
124-
125- AggregateChange <T > change = createInsertChange (instance );
126132
127- return store (instance , identifierAccessor , change , persistentEntity );
133+ return store (instance , this :: createInsertChange , persistentEntity );
128134 }
129135
130136 /**
@@ -140,11 +146,8 @@ public <T> T update(T instance) {
140146 Assert .notNull (instance , "Aggregate instance must not be null!" );
141147
142148 RelationalPersistentEntity <?> persistentEntity = context .getRequiredPersistentEntity (instance .getClass ());
143- IdentifierAccessor identifierAccessor = persistentEntity .getIdentifierAccessor (instance );
144149
145- AggregateChange <T > change = createUpdateChange (instance );
146-
147- return store (instance , identifierAccessor , change , persistentEntity );
150+ return store (instance , this ::createUpdateChange , persistentEntity );
148151 }
149152
150153 /*
@@ -171,7 +174,7 @@ public <T> T findById(Object id, Class<T> domainType) {
171174
172175 T entity = accessStrategy .findById (id , domainType );
173176 if (entity != null ) {
174- publishAfterLoad (id , entity );
177+ return triggerAfterLoad (id , entity );
175178 }
176179 return entity ;
177180 }
@@ -199,8 +202,7 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
199202 Assert .notNull (domainType , "Domain type must not be null!" );
200203
201204 Iterable <T > all = accessStrategy .findAll (domainType );
202- publishAfterLoad (all );
203- return all ;
205+ return triggerAfterLoad (all );
204206 }
205207
206208 /*
@@ -214,8 +216,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
214216 Assert .notNull (domainType , "Domain type must not be null!" );
215217
216218 Iterable <T > allById = accessStrategy .findAllById (ids , domainType );
217- publishAfterLoad (allById );
218- return allById ;
219+ return triggerAfterLoad (allById );
219220 }
220221
221222 /*
@@ -260,48 +261,47 @@ public void deleteAll(Class<?> domainType) {
260261 change .executeWith (interpreter , context , converter );
261262 }
262263
263- private <T > T store (T instance , IdentifierAccessor identifierAccessor , AggregateChange <T > change ,
264+ private <T > T store (T aggregateRoot , Function < T , AggregateChange <T >> changeCreator ,
264265 RelationalPersistentEntity <?> persistentEntity ) {
265266
266- Assert .notNull (instance , "Aggregate instance must not be null!" );
267+ Assert .notNull (aggregateRoot , "Aggregate instance must not be null!" );
267268
268- publisher .publishEvent (new BeforeSaveEvent ( //
269- Identifier .ofNullable (identifierAccessor .getIdentifier ()), //
270- instance , //
271- change //
272- ));
269+ aggregateRoot = triggerBeforeConvert (aggregateRoot ,
270+ persistentEntity .getIdentifierAccessor (aggregateRoot ).getIdentifier ());
271+
272+ AggregateChange <T > change = changeCreator .apply (aggregateRoot );
273+
274+ aggregateRoot = triggerBeforeSave (aggregateRoot ,
275+ persistentEntity .getIdentifierAccessor (aggregateRoot ).getIdentifier (), change );
276+
277+ change .setEntity (aggregateRoot );
273278
274279 change .executeWith (interpreter , context , converter );
275280
276281 Object identifier = persistentEntity .getIdentifierAccessor (change .getEntity ()).getIdentifier ();
277282
278283 Assert .notNull (identifier , "After saving the identifier must not be null!" );
279284
280- publisher .publishEvent (new AfterSaveEvent ( //
281- Identifier .of (identifier ), //
282- change .getEntity (), //
283- change //
284- ));
285-
286- return (T ) change .getEntity ();
285+ return triggerAfterSave (change .getEntity (), identifier , change );
287286 }
288287
289- private void deleteTree (Object id , @ Nullable Object entity , Class <? > domainType ) {
288+ private < T > void deleteTree (Object id , @ Nullable T entity , Class <T > domainType ) {
290289
291- AggregateChange <? > change = createDeletingChange (id , entity , domainType );
290+ AggregateChange <T > change = createDeletingChange (id , entity , domainType );
292291
293- Specified specifiedId = Identifier .of (id );
294- Optional <Object > optionalEntity = Optional .ofNullable (entity );
295- publisher .publishEvent (new BeforeDeleteEvent (specifiedId , optionalEntity , change ));
292+ entity = triggerBeforeDelete (entity , id , change );
293+ change .setEntity (entity );
296294
297295 change .executeWith (interpreter , context , converter );
298296
299- publisher . publishEvent ( new AfterDeleteEvent ( specifiedId , optionalEntity , change ) );
297+ triggerAfterDelete ( entity , id , change );
300298 }
301299
302300 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
303301 private <T > AggregateChange <T > createChange (T instance ) {
304302
303+ // context.getRequiredPersistentEntity(o.getClass()).isNew(o)
304+
305305 AggregateChange <T > aggregateChange = new AggregateChange (Kind .SAVE , instance .getClass (), instance );
306306 jdbcEntityWriter .write (instance , aggregateChange );
307307 return aggregateChange ;
@@ -324,9 +324,9 @@ private <T> AggregateChange<T> createUpdateChange(T instance) {
324324 }
325325
326326 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
327- private AggregateChange <? > createDeletingChange (Object id , @ Nullable Object entity , Class <? > domainType ) {
327+ private < T > AggregateChange <T > createDeletingChange (Object id , @ Nullable T entity , Class <T > domainType ) {
328328
329- AggregateChange <? > aggregateChange = new AggregateChange (Kind .DELETE , domainType , entity );
329+ AggregateChange <T > aggregateChange = new AggregateChange (Kind .DELETE , domainType , entity );
330330 jdbcEntityDeleteWriter .write (id , aggregateChange );
331331 return aggregateChange ;
332332 }
@@ -338,18 +338,96 @@ private AggregateChange<?> createDeletingChange(Class<?> domainType) {
338338 return aggregateChange ;
339339 }
340340
341- private <T > void publishAfterLoad (Iterable <T > all ) {
341+ private <T > Iterable < T > triggerAfterLoad (Iterable <T > all ) {
342342
343- for ( T e : all ) {
343+ return StreamSupport . stream ( all . spliterator (), false ). map ( e -> {
344344
345345 RelationalPersistentEntity <?> entity = context .getRequiredPersistentEntity (e .getClass ());
346346 IdentifierAccessor identifierAccessor = entity .getIdentifierAccessor (e );
347347
348- publishAfterLoad (identifierAccessor .getRequiredIdentifier (), e );
349- }
348+ return triggerAfterLoad (identifierAccessor .getRequiredIdentifier (), e );
349+ }). collect ( Collectors . toList ());
350350 }
351351
352- private <T > void publishAfterLoad (Object id , T entity ) {
352+ private <T > T triggerAfterLoad (Object id , T entity ) {
353+
353354 publisher .publishEvent (new AfterLoadEvent (Identifier .of (id ), entity ));
355+
356+ return entityCallbacks .callback (AfterLoadCallback .class , entity , Identifier .of (id ));
357+ }
358+
359+ private <T > T triggerBeforeConvert (T aggregateRoot , @ Nullable Object id ) {
360+
361+ Identifier identifier = Identifier .ofNullable (id );
362+
363+ return entityCallbacks .callback (BeforeConvertCallback .class , aggregateRoot , identifier );
364+ }
365+
366+ private <T > T triggerBeforeSave (T aggregateRoot , @ Nullable Object id , AggregateChange <T > change ) {
367+
368+ Identifier identifier = Identifier .ofNullable (id );
369+
370+ publisher .publishEvent (new BeforeSaveEvent ( //
371+ identifier , //
372+ aggregateRoot , //
373+ change //
374+ ));
375+
376+ return entityCallbacks .callback (BeforeSaveCallback .class , aggregateRoot , identifier );
377+ }
378+
379+ private <T > T triggerAfterSave (T aggregateRoot , Object id , AggregateChange <T > change ) {
380+
381+ Specified identifier = Identifier .of (id );
382+
383+ publisher .publishEvent (new AfterSaveEvent ( //
384+ identifier , //
385+ aggregateRoot , //
386+ change //
387+ ));
388+
389+ return entityCallbacks .callback (AfterSaveCallback .class , aggregateRoot , identifier );
390+ }
391+
392+ @ Nullable
393+ private <T > void triggerAfterDelete (@ Nullable T aggregateRoot , Object id , AggregateChange <?> change ) {
394+
395+ Specified identifier = Identifier .of (id );
396+
397+ publisher .publishEvent (new AfterDeleteEvent (identifier , Optional .ofNullable (aggregateRoot ), change ));
398+
399+ if (aggregateRoot != null ) {
400+ entityCallbacks .callback (AfterDeleteCallback .class , aggregateRoot , identifier );
401+ }
402+ }
403+
404+ @ Nullable
405+ private <T > T triggerBeforeDelete (@ Nullable T aggregateRoot , Object id , AggregateChange <?> change ) {
406+
407+ Specified identifier = Identifier .of (id );
408+
409+ publisher .publishEvent (new BeforeDeleteEvent (identifier , Optional .ofNullable (aggregateRoot ), change ));
410+
411+ if (aggregateRoot != null ) {
412+ return entityCallbacks .callback (BeforeDeleteCallback .class , aggregateRoot , identifier );
413+ }
414+ return aggregateRoot ;
415+ }
416+
417+ /**
418+ * An {@link EntityCallbacks} implementation doing nothing.
419+ */
420+ private enum NoopEntityCallback implements EntityCallbacks {
421+
422+ INSTANCE {
423+
424+ @ Override
425+ public void addEntityCallback (EntityCallback <?> callback ) {}
426+
427+ @ Override
428+ public <T > T callback (Class <? extends EntityCallback > callbackType , T entity , Object ... args ) {
429+ return entity ;
430+ }
431+ }
354432 }
355433}
0 commit comments