From 42c60065368724b841bfcf8a39f46fbbb1708f9d Mon Sep 17 00:00:00 2001 From: Arturo Bernal Date: Tue, 25 Mar 2025 14:21:22 +0100 Subject: [PATCH] Simplify ProtocolSwitchStrategy by Leveraging ProtocolVersionParser Unify HTTP and TLS token parsing in the Upgrade header by replacing custom version parsing with ProtocolVersionParser. This change removes redundant code and ensures that only supported protocols (HTTP/ and TLS tokens) are accepted, while all other upgrade protocols are rejected as unsupported. --- .../http/impl/ProtocolSwitchStrategy.java | 120 +++++++++++++++--- .../http/impl/TestProtocolSwitchStrategy.java | 119 ++++++++++++++++- 2 files changed, 219 insertions(+), 20 deletions(-) diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java index 6e5133fe74..061690eb95 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java @@ -27,15 +27,22 @@ package org.apache.hc.client5.http.impl; import java.util.Iterator; +import java.util.concurrent.atomic.AtomicReference; import org.apache.hc.core5.annotation.Internal; +import org.apache.hc.core5.http.FormattedHeader; +import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.HttpMessage; +import org.apache.hc.core5.http.HttpVersion; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.ProtocolException; import org.apache.hc.core5.http.ProtocolVersion; -import org.apache.hc.core5.http.message.MessageSupport; +import org.apache.hc.core5.http.ProtocolVersionParser; import org.apache.hc.core5.http.ssl.TLS; +import org.apache.hc.core5.util.Args; +import org.apache.hc.core5.util.CharArrayBuffer; +import org.apache.hc.core5.util.Tokenizer; /** * Protocol switch handler. @@ -45,31 +52,106 @@ @Internal public final class ProtocolSwitchStrategy { - enum ProtocolSwitch { FAILURE, TLS } + private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE; + + private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE; + + private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(','); + + @FunctionalInterface + private interface HeaderConsumer { + void accept(CharSequence buffer, Tokenizer.Cursor cursor) throws ProtocolException; + } public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException { - final Iterator it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE); + final AtomicReference tlsUpgrade = new AtomicReference<>(); - ProtocolVersion tlsUpgrade = null; - while (it.hasNext()) { - final String token = it.next(); - if (token.startsWith("TLS")) { - // TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method - try { - tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv")); - } catch (final ParseException ex) { - throw new ProtocolException("Invalid protocol: " + token); + parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> { + while (!cursor.atEnd()) { + TOKENIZER.skipWhiteSpace(buffer, cursor); + if (cursor.atEnd()) { + break; + } + final int tokenStart = cursor.getPos(); + TOKENIZER.parseToken(buffer, cursor, UPGRADE_TOKEN_DELIMITER); + final int tokenEnd = cursor.getPos(); + if (tokenStart < tokenEnd) { + final ProtocolVersion version = parseProtocolToken(buffer, tokenStart, tokenEnd); + if (version != null && "TLS".equalsIgnoreCase(version.getProtocol())) { + tlsUpgrade.set(version); + } } - } else if (token.equals("HTTP/1.1")) { - // TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method + if (!cursor.atEnd()) { + cursor.updatePos(cursor.getPos() + 1); + } + } + }); + + final ProtocolVersion result = tlsUpgrade.get(); + if (result != null) { + return result; + } else { + throw new ProtocolException("Invalid protocol switch response: no TLS version found"); + } + } + + private ProtocolVersion parseProtocolToken(final CharSequence buffer, final int start, final int end) + throws ProtocolException { + if (start >= end) { + return null; + } + + if (end - start == 3) { + final char c0 = buffer.charAt(start); + final char c1 = buffer.charAt(start + 1); + final char c2 = buffer.charAt(start + 2); + if ((c0 == 'T' || c0 == 't') && + (c1 == 'L' || c1 == 'l') && + (c2 == 'S' || c2 == 's')) { + return TLS.V_1_2.getVersion(); + } + } + + try { + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(start, end); + final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null); + + if ("TLS".equalsIgnoreCase(version.getProtocol())) { + return version; + } else if (version.equals(HttpVersion.HTTP_1_1)) { + return null; } else { - throw new ProtocolException("Unsupported protocol: " + token); + throw new ProtocolException("Unsupported protocol or HTTP version: " + buffer.subSequence(start, end)); } + } catch (final ParseException ex) { + throw new ProtocolException("Invalid protocol: " + buffer.subSequence(start, end), ex); } - if (tlsUpgrade == null) { - throw new ProtocolException("Invalid protocol switch response"); + } + + private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer) + throws ProtocolException { + Args.notNull(message, "Message headers"); + Args.notBlank(name, "Header name"); + final Iterator
it = message.headerIterator(name); + while (it.hasNext()) { + parseHeader(it.next(), consumer); } - return tlsUpgrade; } -} + private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException { + Args.notNull(header, "Header"); + if (header instanceof FormattedHeader) { + final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer(); + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, buf.length()); + cursor.updatePos(((FormattedHeader) header).getValuePos()); + consumer.accept(buf, cursor); + } else { + final String value = header.getValue(); + if (value == null) { + return; + } + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, value.length()); + consumer.accept(value, cursor); + } + } +} \ No newline at end of file diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java index 9c8593a44b..2c7c322818 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java @@ -30,6 +30,7 @@ import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.ProtocolException; +import org.apache.hc.core5.http.ProtocolVersion; import org.apache.hc.core5.http.message.BasicHttpResponse; import org.apache.hc.core5.http.ssl.TLS; import org.junit.jupiter.api.Assertions; @@ -37,7 +38,7 @@ import org.junit.jupiter.api.Test; /** - * Simple tests for {@link DefaultAuthenticationStrategy}. + * Simple tests for {@link ProtocolSwitchStrategy}. */ class TestProtocolSwitchStrategy { @@ -95,4 +96,120 @@ void testSwitchInvalid() { Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3)); } + @Test + void testNullToken() throws ProtocolException { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS,"); + response.addHeader(HttpHeaders.UPGRADE, null); + Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); + } + + @Test + void testWhitespaceOnlyToken() throws ProtocolException { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, " , TLS"); + Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedTlsVersion() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4"); + Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedTlsMajorVersion() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0"); + Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedHttpVersion() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Unsupported HTTP version: HTTP/2.0"); + } + + @Test + void testInvalidTlsFormat() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/abc"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLS/abc"); + } + + @Test + void testHttp11Only() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol switch response: no TLS version found"); + } + + @Test + void testSwitchToTlsValid_TLS_1_2() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_2.getVersion(), result); + } + + @Test + void testSwitchToTlsValid_TLS_1_0() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_0.getVersion(), result); + } + + @Test + void testSwitchToTlsValid_TLS_1_1() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_1.getVersion(), result); + } + + @Test + void testInvalidTlsFormat_NoSlash() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLSv1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLSv1"); + } + + @Test + void testSwitchToTlsValid_TLS_1() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_0.getVersion(), result); + } + + @Test + void testInvalidTlsFormat_MissingMajor() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLS/.1"); + } + + @Test + void testMultipleHttp11Tokens() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol switch response: no TLS version found"); + } + + @Test + void testMixedInvalidAndValidTokens() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: Crap"); + } } \ No newline at end of file