Skip to content

HTTP/2 improvements for CVE-2023-36478 #9749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.eclipse.jetty.http2.HTTP2Connection;
import org.eclipse.jetty.http2.ISession;
import org.eclipse.jetty.http2.api.Session;
import org.eclipse.jetty.http2.frames.Frame;
import org.eclipse.jetty.http2.frames.PrefaceFrame;
import org.eclipse.jetty.http2.frames.SettingsFrame;
import org.eclipse.jetty.http2.frames.WindowUpdateFrame;
Expand Down Expand Up @@ -58,23 +59,25 @@ public Connection newConnection(EndPoint endPoint, Map<String, Object> context)

Generator generator = new Generator(byteBufferPool, client.getMaxDynamicTableSize(), client.getMaxHeaderBlockFragment());
FlowControlStrategy flowControl = client.getFlowControlStrategyFactory().newFlowControlStrategy();
HTTP2ClientSession session = new HTTP2ClientSession(scheduler, endPoint, generator, listener, flowControl);

Parser parser = new Parser(byteBufferPool, 4096, 8192);
parser.setMaxFrameLength(client.getMaxFrameLength());
parser.setMaxSettingsKeys(client.getMaxSettingsKeys());

HTTP2ClientSession session = new HTTP2ClientSession(scheduler, endPoint, parser, generator, listener, flowControl);
session.setMaxRemoteStreams(client.getMaxConcurrentPushedStreams());
long streamIdleTimeout = client.getStreamIdleTimeout();
if (streamIdleTimeout > 0)
session.setStreamIdleTimeout(streamIdleTimeout);

Parser parser = new Parser(byteBufferPool, session, 4096, 8192);
parser.setMaxFrameLength(client.getMaxFrameLength());
parser.setMaxSettingsKeys(client.getMaxSettingsKeys());

RetainableByteBufferPool retainableByteBufferPool = byteBufferPool.asRetainableByteBufferPool();

HTTP2ClientConnection connection = new HTTP2ClientConnection(client, retainableByteBufferPool, executor, endPoint,
parser, session, client.getInputBufferSize(), promise, listener);
session, client.getInputBufferSize(), promise, listener);
connection.setUseInputDirectByteBuffers(client.isUseInputDirectByteBuffers());
connection.setUseOutputDirectByteBuffers(client.isUseOutputDirectByteBuffers());
connection.addEventListener(connectionListener);
parser.init(connection);

return customize(connection, context);
}

Expand All @@ -84,9 +87,9 @@ private static class HTTP2ClientConnection extends HTTP2Connection implements Ca
private final Promise<Session> promise;
private final Session.Listener listener;

private HTTP2ClientConnection(HTTP2Client client, RetainableByteBufferPool retainableByteBufferPool, Executor executor, EndPoint endpoint, Parser parser, ISession session, int bufferSize, Promise<Session> promise, Session.Listener listener)
private HTTP2ClientConnection(HTTP2Client client, RetainableByteBufferPool retainableByteBufferPool, Executor executor, EndPoint endpoint, HTTP2ClientSession session, int bufferSize, Promise<Session> promise, Session.Listener listener)
{
super(retainableByteBufferPool, executor, endpoint, parser, session, bufferSize);
super(retainableByteBufferPool, executor, endpoint, session, bufferSize);
this.client = client;
this.promise = promise;
this.listener = listener;
Expand All @@ -98,12 +101,42 @@ public void onOpen()
Map<Integer, Integer> settings = listener.onPreface(getSession());
if (settings == null)
settings = new HashMap<>();
settings.computeIfAbsent(SettingsFrame.INITIAL_WINDOW_SIZE, k -> client.getInitialStreamRecvWindow());
settings.computeIfAbsent(SettingsFrame.MAX_CONCURRENT_STREAMS, k -> client.getMaxConcurrentPushedStreams());

Integer maxFrameLength = settings.get(SettingsFrame.MAX_FRAME_SIZE);
if (maxFrameLength != null)
getParser().setMaxFrameLength(maxFrameLength);
// Here we want to populate any settings to send to the server
// that have a different default than what prescribed by the RFC.
// Changing the configuration is done when the SETTINGS is sent.

settings.compute(SettingsFrame.HEADER_TABLE_SIZE, (k, v) ->
{
if (v == null)
{
v = client.getMaxDynamicTableSize();
if (v == 4096)
v = null;
}
return v;
});
settings.computeIfAbsent(SettingsFrame.MAX_CONCURRENT_STREAMS, k -> client.getMaxConcurrentPushedStreams());
settings.compute(SettingsFrame.INITIAL_WINDOW_SIZE, (k, v) ->
{
if (v == null)
{
v = client.getInitialStreamRecvWindow();
if (v == FlowControlStrategy.DEFAULT_WINDOW_SIZE)
v = null;
}
return v;
});
settings.compute(SettingsFrame.MAX_FRAME_SIZE, (k, v) ->
{
if (v == null)
{
v = client.getMaxFrameLength();
if (v == Frame.DEFAULT_MAX_LENGTH)
v = null;
}
return v;
});

PrefaceFrame prefaceFrame = new PrefaceFrame();
SettingsFrame settingsFrame = new SettingsFrame(settings, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

package org.eclipse.jetty.http2.client;

import java.util.Map;

import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http2.CloseState;
import org.eclipse.jetty.http2.ErrorCode;
Expand All @@ -23,7 +25,9 @@
import org.eclipse.jetty.http2.api.Stream;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.http2.frames.PushPromiseFrame;
import org.eclipse.jetty.http2.frames.SettingsFrame;
import org.eclipse.jetty.http2.generator.Generator;
import org.eclipse.jetty.http2.parser.Parser;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.thread.Scheduler;
Expand All @@ -34,9 +38,9 @@ public class HTTP2ClientSession extends HTTP2Session
{
private static final Logger LOG = LoggerFactory.getLogger(HTTP2ClientSession.class);

public HTTP2ClientSession(Scheduler scheduler, EndPoint endPoint, Generator generator, Session.Listener listener, FlowControlStrategy flowControl)
public HTTP2ClientSession(Scheduler scheduler, EndPoint endPoint, Parser parser, Generator generator, Session.Listener listener, FlowControlStrategy flowControl)
{
super(scheduler, endPoint, generator, listener, flowControl, 1);
super(scheduler, endPoint, parser, generator, listener, flowControl, 1);
}

@Override
Expand Down Expand Up @@ -87,12 +91,30 @@ public void onHeaders(HeadersFrame frame)
}
}

@Override
public void onSettings(SettingsFrame frame)
{
Map<Integer, Integer> settings = frame.getSettings();
Integer value = settings.get(SettingsFrame.ENABLE_PUSH);
// SPEC: servers can only send ENABLE_PUSH=0.
if (value != null && value != 0)
onConnectionFailure(ErrorCode.PROTOCOL_ERROR.code, "invalid_settings_frame");
else
super.onSettings(frame);
}

@Override
public void onPushPromise(PushPromiseFrame frame)
{
if (LOG.isDebugEnabled())
LOG.debug("Received {}", frame);

if (!isPushEnabled())
{
onConnectionFailure(ErrorCode.PROTOCOL_ERROR.code, "unexpected_push_promise_frame");
return;
}

int streamId = frame.getStreamId();
int pushStreamId = frame.getPromisedStreamId();
IStream stream = getStream(streamId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.http2.AbstractFlowControlStrategy;
import org.eclipse.jetty.http2.BufferingFlowControlStrategy;
import org.eclipse.jetty.http2.ErrorCode;
import org.eclipse.jetty.http2.FlowControlStrategy;
Expand Down Expand Up @@ -64,6 +65,7 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -204,11 +206,9 @@ public Stream.Listener onNewStream(Stream stream, HeadersFrame frame)
SettingsFrame frame = new SettingsFrame(settings, false);
FutureCallback callback = new FutureCallback();
clientSession.settings(frame, callback);
callback.get(5, TimeUnit.SECONDS);

await().atMost(5, TimeUnit.SECONDS).until(() -> clientStream1.getRecvWindow() == 0);
assertEquals(FlowControlStrategy.DEFAULT_WINDOW_SIZE, clientStream1.getSendWindow());
assertEquals(0, clientStream1.getRecvWindow());
settingsLatch.await(5, TimeUnit.SECONDS);

// Now create a new stream, it must pick up the new value.
MetaData.Request request2 = newRequest("POST", HttpFields.EMPTY);
Expand Down Expand Up @@ -343,6 +343,11 @@ public Stream.Listener onNewStream(Stream stream, HeadersFrame requestFrame)
completable.thenRun(settingsLatch::countDown);

assertTrue(settingsLatch.await(5, TimeUnit.SECONDS));
await().atMost(5, TimeUnit.SECONDS).until(() ->
{
AbstractFlowControlStrategy flow = (AbstractFlowControlStrategy)((HTTP2Session)session).getFlowControlStrategy();
return flow.getInitialStreamRecvWindow() == windowSize;
});

CountDownLatch dataLatch = new CountDownLatch(1);
Exchanger<Callback> exchanger = new Exchanger<>();
Expand Down Expand Up @@ -403,13 +408,14 @@ public void testClientFlowControlOneBigWrite() throws Exception
{
int windowSize = 1536;
Exchanger<Callback> exchanger = new Exchanger<>();
CountDownLatch settingsLatch = new CountDownLatch(1);
CountDownLatch dataLatch = new CountDownLatch(1);
AtomicReference<HTTP2Session> serverSessionRef = new AtomicReference<>();
start(new ServerSessionListener.Adapter()
{
@Override
public Map<Integer, Integer> onPreface(Session session)
{
serverSessionRef.set((HTTP2Session)session);
Map<Integer, Integer> settings = new HashMap<>();
settings.put(SettingsFrame.INITIAL_WINDOW_SIZE, windowSize);
return settings;
Expand Down Expand Up @@ -458,21 +464,18 @@ else if (dataFrames == 3 || dataFrames == 4 || dataFrames == 5)
}
});

Session session = newClient(new Session.Listener.Adapter()
Session clientSession = newClient(new Session.Listener.Adapter());

await().atMost(5, TimeUnit.SECONDS).until(() ->
{
@Override
public void onSettings(Session session, SettingsFrame frame)
{
settingsLatch.countDown();
}
AbstractFlowControlStrategy flow = (AbstractFlowControlStrategy)serverSessionRef.get().getFlowControlStrategy();
return flow.getInitialStreamRecvWindow() == windowSize;
});

assertTrue(settingsLatch.await(5, TimeUnit.SECONDS));

MetaData.Request metaData = newRequest("GET", HttpFields.EMPTY);
HeadersFrame requestFrame = new HeadersFrame(metaData, null, false);
FuturePromise<Stream> streamPromise = new FuturePromise<>();
session.newStream(requestFrame, streamPromise, null);
clientSession.newStream(requestFrame, streamPromise, null);
Stream stream = streamPromise.get(5, TimeUnit.SECONDS);

int length = 5 * windowSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.eclipse.jetty.http2.api.Stream;
import org.eclipse.jetty.http2.api.server.ServerSessionListener;
import org.eclipse.jetty.http2.frames.DataFrame;
import org.eclipse.jetty.http2.frames.FrameType;
import org.eclipse.jetty.http2.frames.GoAwayFrame;
import org.eclipse.jetty.http2.frames.HeadersFrame;
import org.eclipse.jetty.http2.frames.ResetFrame;
Expand Down Expand Up @@ -1103,4 +1104,70 @@ public void onReset(Stream stream, ResetFrame frame)
Assertions.assertFalse(((HTTP2Session)serverSessionRef.get()).getEndPoint().isOpen());
Assertions.assertFalse(((HTTP2Session)clientSession).getEndPoint().isOpen());
}

@Test
public void testGoAwayNonZeroStreamId() throws Exception
{
CountDownLatch serverGoAwayLatch = new CountDownLatch(1);
CountDownLatch serverFailureLatch = new CountDownLatch(1);
CountDownLatch serverCloseLatch = new CountDownLatch(1);
start(new ServerSessionListener.Adapter()
{
@Override
public void onGoAway(Session session, GoAwayFrame frame)
{
serverGoAwayLatch.countDown();
}

@Override
public void onFailure(Session session, Throwable failure)
{
serverFailureLatch.countDown();
}

@Override
public void onClose(Session session, GoAwayFrame frame)
{
serverCloseLatch.countDown();
}
});

CountDownLatch clientGoAwayLatch = new CountDownLatch(1);
CountDownLatch clientCloseLatch = new CountDownLatch(1);
Session clientSession = newClient(new Session.Listener.Adapter()
{
@Override
public void onGoAway(Session session, GoAwayFrame frame)
{
clientGoAwayLatch.countDown();
}

@Override
public void onClose(Session session, GoAwayFrame frame)
{
clientCloseLatch.countDown();
}
});

// Wait until the client has finished the previous writes.
Thread.sleep(1000);
// Write an invalid GOAWAY frame.
ByteBuffer byteBuffer = ByteBuffer.allocate(17)
.put((byte)0)
.put((byte)0)
.put((byte)8)
.put((byte)FrameType.GO_AWAY.getType())
.put((byte)0)
.putInt(1) // Non-Zero Stream ID
.putInt(0)
.putInt(ErrorCode.PROTOCOL_ERROR.code)
.flip();
((HTTP2Session)clientSession).getEndPoint().write(Callback.NOOP, byteBuffer);

Assertions.assertFalse(serverGoAwayLatch.await(1, TimeUnit.SECONDS));
Assertions.assertTrue(serverFailureLatch.await(5, TimeUnit.SECONDS));
Assertions.assertTrue(serverCloseLatch.await(5, TimeUnit.SECONDS));
Assertions.assertTrue(clientGoAwayLatch.await(5, TimeUnit.SECONDS));
Assertions.assertTrue(clientCloseLatch.await(5, TimeUnit.SECONDS));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@
import org.eclipse.jetty.http2.frames.ResetFrame;
import org.eclipse.jetty.http2.frames.SettingsFrame;
import org.eclipse.jetty.http2.hpack.HpackException;
import org.eclipse.jetty.http2.parser.RateControl;
import org.eclipse.jetty.http2.parser.ServerParser;
import org.eclipse.jetty.http2.server.RawHTTP2ServerConnectionFactory;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
Expand Down Expand Up @@ -737,6 +734,7 @@ public void onFailure(Session session, Throwable failure)
@Test
public void testGoAwayRespondedWithGoAway() throws Exception
{
CountDownLatch goAwayLatch = new CountDownLatch(1);
ServerSessionListener.Adapter serverListener = new ServerSessionListener.Adapter()
{
@Override
Expand All @@ -748,24 +746,14 @@ public Stream.Listener onNewStream(Stream stream, HeadersFrame frame)
stream.getSession().close(ErrorCode.NO_ERROR.code, null, Callback.NOOP);
return null;
}
};
CountDownLatch goAwayLatch = new CountDownLatch(1);
RawHTTP2ServerConnectionFactory connectionFactory = new RawHTTP2ServerConnectionFactory(new HttpConfiguration(), serverListener)
{

@Override
protected ServerParser newServerParser(Connector connector, ServerParser.Listener listener, RateControl rateControl)
public void onGoAway(Session session, GoAwayFrame frame)
{
return super.newServerParser(connector, new ServerParser.Listener.Wrapper(listener)
{
@Override
public void onGoAway(GoAwayFrame frame)
{
super.onGoAway(frame);
goAwayLatch.countDown();
}
}, rateControl);
goAwayLatch.countDown();
}
};
RawHTTP2ServerConnectionFactory connectionFactory = new RawHTTP2ServerConnectionFactory(new HttpConfiguration(), serverListener);
prepareServer(connectionFactory);
server.start();

Expand Down
Loading