1111import  java .util .concurrent .atomic .AtomicReference ;
1212import  java .util .function .Function ;
1313
14- import  org .slf4j .Logger ;
15- import  org .slf4j .LoggerFactory ;
16- 
1714import  io .modelcontextprotocol .spec .McpClientSession ;
1815import  io .modelcontextprotocol .spec .McpError ;
1916import  io .modelcontextprotocol .spec .McpSchema ;
2017import  io .modelcontextprotocol .spec .McpTransportSessionNotFoundException ;
2118import  io .modelcontextprotocol .util .Assert ;
19+ import  org .slf4j .Logger ;
20+ import  org .slf4j .LoggerFactory ;
2221import  reactor .core .publisher .Mono ;
2322import  reactor .core .publisher .Sinks ;
2423import  reactor .util .context .ContextView ;
@@ -99,21 +98,30 @@ class LifecycleInitializer {
9998	 */ 
10099	private  final  Duration  initializationTimeout ;
101100
101+ 	/** 
102+ 	 * Post-initialization hook to perform additional operations after every successful 
103+ 	 * initialization. 
104+ 	 */ 
105+ 	private  final  Function <Initialization , Mono <Void >> postInitializationHook ;
106+ 
102107	public  LifecycleInitializer (McpSchema .ClientCapabilities  clientCapabilities , McpSchema .Implementation  clientInfo ,
103108			List <String > protocolVersions , Duration  initializationTimeout ,
104- 			Function <ContextView , McpClientSession > sessionSupplier ) {
109+ 			Function <ContextView , McpClientSession > sessionSupplier ,
110+ 			Function <Initialization , Mono <Void >> postInitializationHook ) {
105111
106112		Assert .notNull (sessionSupplier , "Session supplier must not be null" );
107113		Assert .notNull (clientCapabilities , "Client capabilities must not be null" );
108114		Assert .notNull (clientInfo , "Client info must not be null" );
109115		Assert .notEmpty (protocolVersions , "Protocol versions must not be empty" );
110116		Assert .notNull (initializationTimeout , "Initialization timeout must not be null" );
117+ 		Assert .notNull (postInitializationHook , "Post-initialization hook must not be null" );
111118
112119		this .sessionSupplier  = sessionSupplier ;
113120		this .clientCapabilities  = clientCapabilities ;
114121		this .clientInfo  = clientInfo ;
115122		this .protocolVersions  = Collections .unmodifiableList (new  ArrayList <>(protocolVersions ));
116123		this .initializationTimeout  = initializationTimeout ;
124+ 		this .postInitializationHook  = postInitializationHook ;
117125	}
118126
119127	/** 
@@ -148,10 +156,6 @@ interface Initialization {
148156
149157	}
150158
151- 	/** 
152- 	 * Default implementation of the {@link Initialization} interface that manages the MCP 
153- 	 * client initialization process. 
154- 	 */ 
155159	private  static  class  DefaultInitialization  implements  Initialization  {
156160
157161		/** 
@@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) {
199203			this .mcpClientSession .set (mcpClientSession );
200204		}
201205
202- 		/** 
203- 		 * Returns a Mono that completes when the MCP client initialization is complete. 
204- 		 * This allows subscribers to wait for the initialization to finish before 
205- 		 * proceeding with further operations. 
206- 		 * @return A Mono that emits the result of the MCP initialization process 
207- 		 */ 
208206		private  Mono <McpSchema .InitializeResult > await () {
209207			return  this .initSink .asMono ();
210208		}
211209
212- 		/** 
213- 		 * Completes the initialization process with the given result. It caches the 
214- 		 * result and emits it to all subscribers waiting for the initialization to 
215- 		 * complete. 
216- 		 * @param initializeResult The result of the MCP initialization process 
217- 		 */ 
218210		private  void  complete (McpSchema .InitializeResult  initializeResult ) {
219- 			// first ensure the result is cached 
220- 			this .result .set (initializeResult );
221211			// inform all the subscribers waiting for the initialization 
222212			this .initSink .emitValue (initializeResult , Sinks .EmitFailureHandler .FAIL_FAST );
223213		}
224214
215+ 		private  void  cacheResult (McpSchema .InitializeResult  initializeResult ) {
216+ 			// first ensure the result is cached 
217+ 			this .result .set (initializeResult );
218+ 		}
219+ 
225220		private  void  error (Throwable  t ) {
226221			this .initSink .emitError (t , Sinks .EmitFailureHandler .FAIL_FAST );
227222		}
@@ -263,7 +258,7 @@ public void handleException(Throwable t) {
263258			}
264259			// Providing an empty operation since we are only interested in triggering 
265260			// the implicit initialization step. 
266- 			withIntitialization ("re-initializing" , result  -> Mono .empty ()).subscribe ();
261+ 			this . withInitialization ("re-initializing" , result  -> Mono .empty ()).subscribe ();
267262		}
268263	}
269264
@@ -275,16 +270,16 @@ public void handleException(Throwable t) {
275270	 * @param operation The operation to execute when the client is initialized 
276271	 * @return A Mono that completes with the result of the operation 
277272	 */ 
278- 	public  <T > Mono <T > withIntitialization (String  actionName , Function <Initialization , Mono <T >> operation ) {
273+ 	public  <T > Mono <T > withInitialization (String  actionName , Function <Initialization , Mono <T >> operation ) {
279274		return  Mono .deferContextual (ctx  -> {
280275			DefaultInitialization  newInit  = new  DefaultInitialization ();
281276			DefaultInitialization  previous  = this .initializationRef .compareAndExchange (null , newInit );
282277
283278			boolean  needsToInitialize  = previous  == null ;
284279			logger .debug (needsToInitialize  ? "Initialization process started"  : "Joining previous initialization" );
285280
286- 			Mono <McpSchema .InitializeResult > initializationJob  = needsToInitialize  ?  doInitialize ( newInit ,  ctx ) 
287- 					: previous .await ();
281+ 			Mono <McpSchema .InitializeResult > initializationJob  = needsToInitialize 
282+ 					?  this . doInitialize ( newInit ,  this . postInitializationHook ,  ctx )  : previous .await ();
288283
289284			return  initializationJob .map (initializeResult  -> this .initializationRef .get ())
290285				.timeout (this .initializationTimeout )
@@ -296,7 +291,9 @@ public <T> Mono<T> withIntitialization(String actionName, Function<Initializatio
296291		});
297292	}
298293
299- 	private  Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization  initialization , ContextView  ctx ) {
294+ 	private  Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization  initialization ,
295+ 			Function <Initialization , Mono <Void >> postInitOperation , ContextView  ctx ) {
296+ 
300297		initialization .setMcpClientSession (this .sessionSupplier .apply (ctx ));
301298
302299		McpClientSession  mcpClientSession  = initialization .mcpSession ();
@@ -323,6 +320,9 @@ private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization init
323320
324321			return  mcpClientSession .sendNotification (McpSchema .METHOD_NOTIFICATION_INITIALIZED , null )
325322				.thenReturn (initializeResult );
323+ 		}).flatMap (initializeResult  -> {
324+ 			initialization .cacheResult (initializeResult );
325+ 			return  postInitOperation .apply (initialization ).thenReturn (initializeResult );
326326		}).doOnNext (initialization ::complete ).onErrorResume (ex  -> {
327327			initialization .error (ex );
328328			return  Mono .error (ex );
0 commit comments