Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.util.function.Consumer;
import java.util.function.Function;

import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport.Builder;
import io.modelcontextprotocol.client.transport.ResponseSubscribers.SseResponseEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
Expand Down Expand Up @@ -116,6 +118,8 @@ public class HttpClientSseClientTransport implements McpClientTransport {
*/
private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer;

private final AtomicReference<Consumer<Void>> connectionClosedHandler = new AtomicReference<>();

/**
* Creates a new transport instance with custom HTTP client builder, object mapper,
* and headers.
Expand All @@ -129,7 +133,8 @@ public class HttpClientSseClientTransport implements McpClientTransport {
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
*/
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) {
String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer,
Consumer<Void> connectionClosedHandler) {
Assert.notNull(jsonMapper, "jsonMapper must not be null");
Assert.hasText(baseUri, "baseUri must not be empty");
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Expand All @@ -142,13 +147,28 @@ public class HttpClientSseClientTransport implements McpClientTransport {
this.httpClient = httpClient;
this.requestBuilder = requestBuilder;
this.httpRequestCustomizer = httpRequestCustomizer;
this.connectionClosedHandler.set(connectionClosedHandler);
}

@Override
public List<String> protocolVersions() {
return List.of(ProtocolVersions.MCP_2024_11_05);
}

@Override
public void setConnectionClosedHandler(Consumer<Void> closedHandler) {
logger.debug("Connection closed handler registered");
connectionClosedHandler.set(closedHandler);
}

private void handleConnectionClosed() {
logger.debug("Handling connection closed");
Consumer<Void> handler = this.connectionClosedHandler.get();
if (handler != null) {
handler.accept(null);
}
}

/**
* Creates a new builder for {@link HttpClientSseClientTransport}.
* @param baseUri the base URI of the MCP server
Expand Down Expand Up @@ -177,6 +197,8 @@ public static class Builder {

private Duration connectTimeout = Duration.ofSeconds(10);

private Consumer<Void> connectionClosedHandler = null;

/**
* Creates a new builder instance.
*/
Expand Down Expand Up @@ -320,14 +342,26 @@ public Builder connectTimeout(Duration connectTimeout) {
return this;
}

/**
* Set the connection closed handler.
* @param connectionClosedHandler the connection closed handler
* @return this builder
*/
public Builder connectionClosedHandler(Consumer<Void> connectionClosedHandler) {
Assert.notNull(connectionClosedHandler, "connectionClosedHandler must not be null");
this.connectionClosedHandler = connectionClosedHandler;
return this;
}

/**
* Builds a new {@link HttpClientSseClientTransport} instance.
* @return a new transport instance
*/
public HttpClientSseClientTransport build() {
HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build();
return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint,
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer);
jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, httpRequestCustomizer,
connectionClosedHandler);
}

}
Expand All @@ -352,9 +386,7 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
.exceptionallyCompose(e -> {
sseSink.error(e);
return CompletableFuture.failedFuture(e);
}))
.map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent)
.flatMap(responseEvent -> {
})).map(responseEvent -> (SseResponseEvent) responseEvent).flatMap(responseEvent -> {
if (isClosing) {
return Mono.empty();
}
Expand Down Expand Up @@ -388,26 +420,21 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
sink.error(new McpTransportException("Error processing SSE event", e));
}
}
return Flux.<McpSchema.JSONRPCMessage>error(
new RuntimeException("Failed to send message: " + responseEvent));
return Flux.<JSONRPCMessage>error(new RuntimeException("Failed to send message: " + responseEvent));

})
.flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage)))
.onErrorComplete(t -> {
}).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> {
if (!isClosing) {
logger.warn("SSE stream observed an error", t);
sink.error(t);
}
return true;
})
.doFinally(s -> {
}).doFinally(s -> {
Disposable ref = this.sseSubscription.getAndSet(null);
if (ref != null && !ref.isDisposed()) {
ref.dispose();
}
})
.contextWrite(sink.contextView())
.subscribe();
handleConnectionClosed();
}).contextWrite(sink.contextView()).subscribe();

this.sseSubscription.set(connection);
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,13 @@

package io.modelcontextprotocol.client.transport;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandler;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.json.McpJsonMapper;

import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.spec.ClosedMcpTransportSession;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.DefaultMcpTransportSession;
import io.modelcontextprotocol.spec.DefaultMcpTransportStream;
import io.modelcontextprotocol.spec.HttpHeaders;
Expand All @@ -42,13 +23,30 @@
import io.modelcontextprotocol.spec.ProtocolVersions;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandler;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

/**
* An implementation of the Streamable HTTP protocol as defined by the
* <code>2025-03-26</code> version of the MCP specification.
Expand Down Expand Up @@ -125,9 +123,12 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport {

private final AtomicReference<Consumer<Throwable>> exceptionHandler = new AtomicReference<>();

private final AtomicReference<Consumer<Void>> connectionClosedHandler = new AtomicReference<>();

private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient,
HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams,
boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) {
boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer,
Consumer<Void> connectionClosedHandler) {
this.jsonMapper = jsonMapper;
this.httpClient = httpClient;
this.requestBuilder = requestBuilder;
Expand All @@ -137,6 +138,7 @@ private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient h
this.openConnectionOnStartup = openConnectionOnStartup;
this.activeSession.set(createTransportSession());
this.httpRequestCustomizer = httpRequestCustomizer;
this.connectionClosedHandler.set(connectionClosedHandler);
}

@Override
Expand Down Expand Up @@ -202,6 +204,12 @@ public void setExceptionHandler(Consumer<Throwable> handler) {
this.exceptionHandler.set(handler);
}

@Override
public void setConnectionClosedHandler(Consumer<Void> closedHandler) {
logger.debug("Connection closed handler registered");
this.connectionClosedHandler.set(closedHandler);
}

private void handleException(Throwable t) {
logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t);
if (t instanceof McpTransportSessionNotFoundException) {
Expand All @@ -215,6 +223,14 @@ private void handleException(Throwable t) {
}
}

private void handleConnectionClosed() {
logger.debug("Handling connection closed for session {}", sessionIdOrPlaceholder(this.activeSession.get()));
Consumer<Void> handler = this.connectionClosedHandler.get();
if (handler != null) {
handler.accept(null);
}
}

@Override
public Mono<Void> closeGracefully() {
return Mono.defer(() -> {
Expand Down Expand Up @@ -365,6 +381,7 @@ else if (statusCode == BAD_REQUEST) {
if (ref != null) {
transportSession.removeConnection(ref);
}
this.handleConnectionClosed();
}))
.contextWrite(ctx)
.subscribe();
Expand Down Expand Up @@ -624,6 +641,8 @@ public static class Builder {

private Duration connectTimeout = Duration.ofSeconds(10);

private Consumer<Void> connectionClosedHandler = null;

/**
* Creates a new builder with the specified base URI.
* @param baseUri the base URI of the MCP server
Expand Down Expand Up @@ -772,6 +791,17 @@ public Builder connectTimeout(Duration connectTimeout) {
return this;
}

/**
* Set the connection closed handler.
* @param connectionClosedHandler the connection closed handler
* @return this builder
*/
public Builder connectionClosedHandler(Consumer<Void> connectionClosedHandler) {
Assert.notNull(connectionClosedHandler, "connectionClosedHandler must not be null");
this.connectionClosedHandler = connectionClosedHandler;
return this;
}

/**
* Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using
* the current builder configuration.
Expand All @@ -781,7 +811,7 @@ public HttpClientStreamableHttpTransport build() {
HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build();
return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper,
httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup,
httpRequestCustomizer);
httpRequestCustomizer, connectionClosedHandler);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ public interface McpClientTransport extends McpTransport {
default void setExceptionHandler(Consumer<Throwable> handler) {
}

/**
* Sets the handler for the transport closed event.
* @param closedHandler Allows reacting to transport closed event by the higher layers
*/
default void setConnectionClosedHandler(Consumer<Void> closedHandler) {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,36 @@

package io.modelcontextprotocol.client.transport;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.mockito.ArgumentCaptor;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.util.UriComponentsBuilder;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.test.StepVerifier;

import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -78,7 +77,8 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo
public TestHttpClientSseClientTransport(final String baseUri) {
super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(),
HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", JSON_MAPPER,
McpAsyncHttpClientRequestCustomizer.NOOP);
McpAsyncHttpClientRequestCustomizer.NOOP, v -> {
});
}

public int getInboundMessageCount() {
Expand Down Expand Up @@ -130,7 +130,8 @@ void testErrorOnBogusMessage() {

StepVerifier.create(transport.sendMessage(bogusMessage))
.verifyErrorMessage(
"Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}");
"Sending message failed with a non-OK HTTP code: 400 - Invalid message: {\"id\":\"test-id\","
+ "\"params\":{\"key\":\"value\"}}");
}

@Test
Expand Down Expand Up @@ -477,4 +478,14 @@ void testAsyncRequestCustomizer() {
customizedTransport.closeGracefully().block();
}

@Test
void testTransportConnectionClosedHandler() {
AtomicReference<Boolean> closedHandlerCalled = new AtomicReference<>(false);
transport.setConnectionClosedHandler(v -> closedHandlerCalled.set(true));
// transport close simulate the behavior of disconnection
transport.closeGracefully().block();
// Verify the closed handler was called
Assertions.assertTrue(closedHandlerCalled.get());
}

}
Loading