Skip to content

Commit

Permalink
Merge pull request #12186 from jetty/jetty-12.1.x-servletUpgrade
Browse files Browse the repository at this point in the history
implement servlet upgrade for ee10 and ee11
  • Loading branch information
lachlan-roberts authored Sep 4, 2024
2 parents 165327a + f795fb1 commit 8c5d5e8
Show file tree
Hide file tree
Showing 15 changed files with 927 additions and 522 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1865,7 +1865,7 @@ public void onError(Throwable t)
}
});
// Close the parser to cause the issue.
org.eclipse.jetty.server.HttpConnection.getCurrentConnection().getParser().close();
org.eclipse.jetty.server.internal.HttpConnection.getCurrentConnection().getParser().close();
}
});
server.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,18 @@ else if (status == HttpStatus.NO_CONTENT_204 || status == HttpStatus.NOT_MODIFIE
}
}

public void servletUpgrade()
public void startTunnel()
{
_noContentResponse = false;
_state = State.COMMITTED;
}

@Deprecated(since = "12.1.0", forRemoval = true)
public void servletUpgrade()
{
startTunnel();
}

private void prepareChunk(ByteBuffer chunk, int remaining)
{
// if we need CRLF add this to header
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2016,13 +2016,19 @@ public void reset()
_headerComplete = false;
}

public void servletUpgrade()
public void startTunnel()
{
setState(State.CONTENT);
_endOfContent = EndOfContent.UNKNOWN_CONTENT;
setState(State.EOF_CONTENT);
_endOfContent = EndOfContent.EOF_CONTENT;
_contentLength = -1;
}

@Deprecated(since = "12.1.0", forRemoval = true)
public void servletUpgrade()
{
startTunnel();
}

protected void setState(State state)
{
if (debugEnabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ interface UpgradeTo
void onUpgradeTo(ByteBuffer buffer);
}

/**
* <p>Start a tunnel over the current connection without replacing the connection.</p>
* <p>This can be used for upgrade within a connection, but it is not really an upgrade for this connection
* as the connection remains and just tunnels data to/from its endpoint.</p>
*/
interface Tunnel
{
void startTunnel();
}

/**
* <p>A Listener for connection events.</p>
* <p>Listeners can be added to a {@link Connection} to get open and close events.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
/**
* <p>A {@link Connection} that handles the HTTP protocol.</p>
*/
public class HttpConnection extends AbstractMetaDataConnection implements Runnable, Connection.UpgradeFrom, Connection.UpgradeTo, ConnectionMetaData
public class HttpConnection extends AbstractMetaDataConnection implements Runnable, Connection.UpgradeFrom, Connection.UpgradeTo, Connection.Tunnel, ConnectionMetaData
{
private static final Logger LOG = LoggerFactory.getLogger(HttpConnection.class);
private static final HttpField PREAMBLE_UPGRADE_H2C = new HttpField(HttpHeader.UPGRADE, "h2c");
Expand Down Expand Up @@ -336,6 +336,13 @@ public void onUpgradeTo(ByteBuffer buffer)
BufferUtil.append(getRequestBuffer(), buffer);
}

@Override
public void startTunnel()
{
getParser().startTunnel();
getGenerator().startTunnel();
}

void releaseRequestBuffer()
{
if (_retainableByteBuffer != null && _retainableByteBuffer.isEmpty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@

import jakarta.servlet.AsyncContext;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ReadListener;
import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.ServletConnection;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletRequestAttributeEvent;
import jakarta.servlet.ServletRequestAttributeListener;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.WriteListener;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletMapping;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -58,7 +61,10 @@
import jakarta.servlet.http.HttpUpgradeHandler;
import jakarta.servlet.http.Part;
import jakarta.servlet.http.PushBuilder;
import jakarta.servlet.http.WebConnection;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler.ServletRequestInfo;
import org.eclipse.jetty.ee10.servlet.util.ServletInputStreamWrapper;
import org.eclipse.jetty.ee10.servlet.util.ServletOutputStreamWrapper;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.CookieCompliance;
import org.eclipse.jetty.http.HttpCookie;
Expand All @@ -72,6 +78,7 @@
import org.eclipse.jetty.http.MimeTypes;
import org.eclipse.jetty.http.SetCookieParser;
import org.eclipse.jetty.http.pathmap.MatchedResource;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.QuietException;
import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.security.AuthenticationState;
Expand Down Expand Up @@ -737,8 +744,255 @@ public Part getPart(String name) throws IOException, ServletException
@Override
public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException
{
// Not implemented. Throw ServletException as per spec.
throw new ServletException("Not implemented");
Response response = _servletContextRequest.getServletContextResponse();
if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101)
throw new IllegalStateException("Response status should be 101");
if (response.getHeaders().get("Upgrade") == null)
throw new IllegalStateException("Missing Upgrade header");
if (!"Upgrade".equalsIgnoreCase(response.getHeaders().get("Connection")))
throw new IllegalStateException("Invalid Connection header");
if (response.isCommitted())
throw new IllegalStateException("Cannot upgrade committed response");
if (_servletChannel.getConnectionMetaData().getHttpVersion() != HttpVersion.HTTP_1_1)
throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded");

CompletableFuture<Void> outputStreamComplete = new CompletableFuture<>();
CompletableFuture<Void> inputStreamComplete = new CompletableFuture<>();
ServletOutputStream outputStream = new ServletOutputStreamWrapper(_servletContextRequest.getHttpOutput())
{
@Override
public void write(int b) throws IOException
{
try
{
super.write(b);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void write(byte[] b) throws IOException
{
try
{
super.write(b);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void write(byte[] b, int off, int len) throws IOException
{
try
{
super.write(b, off, len);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void close() throws IOException
{
try
{
super.close();
outputStreamComplete.complete(null);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void setWriteListener(WriteListener writeListener)
{
super.setWriteListener(new WriteListener()
{
@Override
public void onWritePossible() throws IOException
{
writeListener.onWritePossible();
}

@Override
public void onError(Throwable t)
{
writeListener.onError(t);
outputStreamComplete.completeExceptionally(t);
}
});
}
};
ServletInputStream inputStream = new ServletInputStreamWrapper(_servletContextRequest.getHttpInput())
{
@Override
public int read() throws IOException
{
try
{
int read = super.read();
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public int read(byte[] b) throws IOException
{
try
{
int read = super.read(b);
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public int read(byte[] b, int off, int len) throws IOException
{
try
{
int read = super.read(b, off, len);
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void close() throws IOException
{
try
{
super.close();
inputStreamComplete.complete(null);
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void setReadListener(ReadListener readListener)
{
super.setReadListener(new ReadListener()
{
@Override
public void onDataAvailable() throws IOException
{
readListener.onDataAvailable();
}

@Override
public void onAllDataRead() throws IOException
{
try
{
readListener.onAllDataRead();
inputStreamComplete.complete(null);
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void onError(Throwable t)
{
readListener.onError(t);
inputStreamComplete.completeExceptionally(t);
}
});
}
};

T upgradeHandler;
try
{
upgradeHandler = handlerClass.getDeclaredConstructor().newInstance();
}
catch (Exception e)
{
throw new ServletException("Unable to instantiate handler class", e);
}

Connection connection = _servletContextRequest.getConnectionMetaData().getConnection();
if (connection instanceof Connection.Tunnel upgradeableConnection)
{
outputStream.flush(); // commit the 101 response
upgradeableConnection.startTunnel();
}
else
{
LOG.warn("Unexpected connection type {}", connection);
throw new IllegalStateException();
}
AsyncContext asyncContext = forceStartAsync(); // force the servlet in async mode
CompletableFuture.allOf(inputStreamComplete, outputStreamComplete).whenComplete((result, failure) ->
{
upgradeHandler.destroy();
asyncContext.complete();
});

WebConnection webConnection = new WebConnection()
{
@Override
public void close() throws Exception
{
IO.close(inputStream);
IO.close(outputStream);
}

@Override
public ServletInputStream getInputStream()
{
return inputStream;
}

@Override
public ServletOutputStream getOutputStream()
{
return outputStream;
}
};

upgradeHandler.init(webConnection);
return upgradeHandler;
}

@Override
Expand Down Expand Up @@ -1374,6 +1628,11 @@ public AsyncContext startAsync() throws IllegalStateException
{
if (!isAsyncSupported())
throw new IllegalStateException("Async Not Supported");
return forceStartAsync();
}

private AsyncContext forceStartAsync()
{
ServletChannelState state = getServletRequestInfo().getState();
if (_async == null)
_async = new AsyncContextState(state);
Expand Down
Loading

0 comments on commit 8c5d5e8

Please sign in to comment.