diff --git a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java index 6e5133fe74..061690eb95 100644 --- a/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java +++ b/httpclient5/src/main/java/org/apache/hc/client5/http/impl/ProtocolSwitchStrategy.java @@ -27,15 +27,22 @@ package org.apache.hc.client5.http.impl; import java.util.Iterator; +import java.util.concurrent.atomic.AtomicReference; import org.apache.hc.core5.annotation.Internal; +import org.apache.hc.core5.http.FormattedHeader; +import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.HttpMessage; +import org.apache.hc.core5.http.HttpVersion; import org.apache.hc.core5.http.ParseException; import org.apache.hc.core5.http.ProtocolException; import org.apache.hc.core5.http.ProtocolVersion; -import org.apache.hc.core5.http.message.MessageSupport; +import org.apache.hc.core5.http.ProtocolVersionParser; import org.apache.hc.core5.http.ssl.TLS; +import org.apache.hc.core5.util.Args; +import org.apache.hc.core5.util.CharArrayBuffer; +import org.apache.hc.core5.util.Tokenizer; /** * Protocol switch handler. @@ -45,31 +52,106 @@ @Internal public final class ProtocolSwitchStrategy { - enum ProtocolSwitch { FAILURE, TLS } + private static final ProtocolVersionParser PROTOCOL_VERSION_PARSER = ProtocolVersionParser.INSTANCE; + + private static final Tokenizer TOKENIZER = Tokenizer.INSTANCE; + + private static final Tokenizer.Delimiter UPGRADE_TOKEN_DELIMITER = Tokenizer.delimiters(','); + + @FunctionalInterface + private interface HeaderConsumer { + void accept(CharSequence buffer, Tokenizer.Cursor cursor) throws ProtocolException; + } public ProtocolVersion switchProtocol(final HttpMessage response) throws ProtocolException { - final Iterator it = MessageSupport.iterateTokens(response, HttpHeaders.UPGRADE); + final AtomicReference tlsUpgrade = new AtomicReference<>(); - ProtocolVersion tlsUpgrade = null; - while (it.hasNext()) { - final String token = it.next(); - if (token.startsWith("TLS")) { - // TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method - try { - tlsUpgrade = token.length() == 3 ? TLS.V_1_2.getVersion() : TLS.parse(token.replace("TLS/", "TLSv")); - } catch (final ParseException ex) { - throw new ProtocolException("Invalid protocol: " + token); + parseHeaders(response, HttpHeaders.UPGRADE, (buffer, cursor) -> { + while (!cursor.atEnd()) { + TOKENIZER.skipWhiteSpace(buffer, cursor); + if (cursor.atEnd()) { + break; + } + final int tokenStart = cursor.getPos(); + TOKENIZER.parseToken(buffer, cursor, UPGRADE_TOKEN_DELIMITER); + final int tokenEnd = cursor.getPos(); + if (tokenStart < tokenEnd) { + final ProtocolVersion version = parseProtocolToken(buffer, tokenStart, tokenEnd); + if (version != null && "TLS".equalsIgnoreCase(version.getProtocol())) { + tlsUpgrade.set(version); + } } - } else if (token.equals("HTTP/1.1")) { - // TODO: Improve handling of HTTP protocol token once HttpVersion has a #parse method + if (!cursor.atEnd()) { + cursor.updatePos(cursor.getPos() + 1); + } + } + }); + + final ProtocolVersion result = tlsUpgrade.get(); + if (result != null) { + return result; + } else { + throw new ProtocolException("Invalid protocol switch response: no TLS version found"); + } + } + + private ProtocolVersion parseProtocolToken(final CharSequence buffer, final int start, final int end) + throws ProtocolException { + if (start >= end) { + return null; + } + + if (end - start == 3) { + final char c0 = buffer.charAt(start); + final char c1 = buffer.charAt(start + 1); + final char c2 = buffer.charAt(start + 2); + if ((c0 == 'T' || c0 == 't') && + (c1 == 'L' || c1 == 'l') && + (c2 == 'S' || c2 == 's')) { + return TLS.V_1_2.getVersion(); + } + } + + try { + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(start, end); + final ProtocolVersion version = PROTOCOL_VERSION_PARSER.parse(buffer, cursor, null); + + if ("TLS".equalsIgnoreCase(version.getProtocol())) { + return version; + } else if (version.equals(HttpVersion.HTTP_1_1)) { + return null; } else { - throw new ProtocolException("Unsupported protocol: " + token); + throw new ProtocolException("Unsupported protocol or HTTP version: " + buffer.subSequence(start, end)); } + } catch (final ParseException ex) { + throw new ProtocolException("Invalid protocol: " + buffer.subSequence(start, end), ex); } - if (tlsUpgrade == null) { - throw new ProtocolException("Invalid protocol switch response"); + } + + private void parseHeaders(final HttpMessage message, final String name, final HeaderConsumer consumer) + throws ProtocolException { + Args.notNull(message, "Message headers"); + Args.notBlank(name, "Header name"); + final Iterator
it = message.headerIterator(name); + while (it.hasNext()) { + parseHeader(it.next(), consumer); } - return tlsUpgrade; } -} + private void parseHeader(final Header header, final HeaderConsumer consumer) throws ProtocolException { + Args.notNull(header, "Header"); + if (header instanceof FormattedHeader) { + final CharArrayBuffer buf = ((FormattedHeader) header).getBuffer(); + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, buf.length()); + cursor.updatePos(((FormattedHeader) header).getValuePos()); + consumer.accept(buf, cursor); + } else { + final String value = header.getValue(); + if (value == null) { + return; + } + final Tokenizer.Cursor cursor = new Tokenizer.Cursor(0, value.length()); + consumer.accept(value, cursor); + } + } +} \ No newline at end of file diff --git a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java index 9c8593a44b..2c7c322818 100644 --- a/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java +++ b/httpclient5/src/test/java/org/apache/hc/client5/http/impl/TestProtocolSwitchStrategy.java @@ -30,6 +30,7 @@ import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.ProtocolException; +import org.apache.hc.core5.http.ProtocolVersion; import org.apache.hc.core5.http.message.BasicHttpResponse; import org.apache.hc.core5.http.ssl.TLS; import org.junit.jupiter.api.Assertions; @@ -37,7 +38,7 @@ import org.junit.jupiter.api.Test; /** - * Simple tests for {@link DefaultAuthenticationStrategy}. + * Simple tests for {@link ProtocolSwitchStrategy}. */ class TestProtocolSwitchStrategy { @@ -95,4 +96,120 @@ void testSwitchInvalid() { Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response3)); } + @Test + void testNullToken() throws ProtocolException { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS,"); + response.addHeader(HttpHeaders.UPGRADE, null); + Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); + } + + @Test + void testWhitespaceOnlyToken() throws ProtocolException { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, " , TLS"); + Assertions.assertEquals(TLS.V_1_2.getVersion(), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedTlsVersion() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.4"); + Assertions.assertEquals(new ProtocolVersion("TLS", 1, 4), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedTlsMajorVersion() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/2.0"); + Assertions.assertEquals(new ProtocolVersion("TLS", 2, 0), switchStrategy.switchProtocol(response)); + } + + @Test + void testUnsupportedHttpVersion() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/2.0"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Unsupported HTTP version: HTTP/2.0"); + } + + @Test + void testInvalidTlsFormat() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/abc"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLS/abc"); + } + + @Test + void testHttp11Only() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol switch response: no TLS version found"); + } + + @Test + void testSwitchToTlsValid_TLS_1_2() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.2"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_2.getVersion(), result); + } + + @Test + void testSwitchToTlsValid_TLS_1_0() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.0"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_0.getVersion(), result); + } + + @Test + void testSwitchToTlsValid_TLS_1_1() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1.1"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_1.getVersion(), result); + } + + @Test + void testInvalidTlsFormat_NoSlash() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLSv1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLSv1"); + } + + @Test + void testSwitchToTlsValid_TLS_1() throws Exception { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/1"); + final ProtocolVersion result = switchStrategy.switchProtocol(response); + Assertions.assertEquals(TLS.V_1_0.getVersion(), result); + } + + @Test + void testInvalidTlsFormat_MissingMajor() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "TLS/.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: TLS/.1"); + } + + @Test + void testMultipleHttp11Tokens() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "HTTP/1.1, HTTP/1.1"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol switch response: no TLS version found"); + } + + @Test + void testMixedInvalidAndValidTokens() { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_SWITCHING_PROTOCOLS); + response.addHeader(HttpHeaders.UPGRADE, "Crap, TLS/1.2, Invalid"); + Assertions.assertThrows(ProtocolException.class, () -> switchStrategy.switchProtocol(response), + "Invalid protocol: Crap"); + } } \ No newline at end of file