diff --git a/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketHandlerImpl.java b/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketHandlerImpl.java index bcfaddd..e8e1bb9 100644 --- a/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketHandlerImpl.java +++ b/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketHandlerImpl.java @@ -64,7 +64,9 @@ public Boolean validateUpgradeReply(ByteBuffer buffer) { buffer.get(data); retVal = webSocketUpgrade.validateUpgradeReply(data); - webSocketUpgrade = null; + if (retVal) { + webSocketUpgrade = null; + } } } diff --git a/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImpl.java b/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImpl.java index c7253d7..ec6f2e4 100644 --- a/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImpl.java +++ b/src/main/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImpl.java @@ -26,7 +26,12 @@ import org.apache.qpid.proton.engine.impl.TransportOutput; import org.apache.qpid.proton.engine.impl.TransportWrapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class WebSocketImpl implements WebSocket, TransportLayer { + private static final Logger TRACE_LOGGER = LoggerFactory.getLogger(WebSocketImpl.class); + private int maxFrameSize = (4 * 1024) + (16 * WebSocketHeader.MED_HEADER_LENGTH_MASKED); private boolean tailClosed = false; private final ByteBuffer inputBuffer; @@ -275,8 +280,13 @@ private boolean sendToUnderlyingInput() { private void processInput() throws TransportException { switch (webSocketState) { case PN_WS_CONNECTING: + inputBuffer.mark(); if (webSocketHandler.validateUpgradeReply(inputBuffer)) { webSocketState = WebSocketState.PN_WS_CONNECTED_FLOW; + } else { + // Input data was incomplete. Reset buffer position and wait for another call after more data arrives. + inputBuffer.reset(); + TRACE_LOGGER.warn("Websocket connecting response incomplete"); } inputBuffer.compact(); break; diff --git a/src/test/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImplTest.java b/src/test/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImplTest.java index ab9280d..3fb19b2 100644 --- a/src/test/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImplTest.java +++ b/src/test/java/com/microsoft/azure/proton/transport/ws/impl/WebSocketImplTest.java @@ -36,6 +36,8 @@ import org.mockito.stubbing.Answer; import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.util.Arrays; import java.util.HashMap; @@ -422,6 +424,60 @@ public void testPending_state_flow_empty_output() { verify(mockWebSocketHandler, times(0)).wrapBuffer((ByteBuffer) any(), (ByteBuffer) any()); } + @Test + public void testChunked_connection() { + init(); + + WebSocketHandlerImpl webSocketHandler = new WebSocketHandlerImpl(); + + WebSocketImpl webSocketImpl = new WebSocketImpl(); + webSocketImpl.configure(_hostName, _webSocketPath, _webSocketQuery, _webSocketPort, _webSocketProtocol, _additionalHeaders, webSocketHandler); + + TransportInput mockTransportInput = mock(TransportInput.class); + TransportOutput mockTransportOutput = mock(TransportOutput.class); + TransportWrapper transportWrapper = webSocketImpl.wrap(mockTransportInput, mockTransportOutput); + + assertTrue(webSocketImpl.getState() == WebSocket.WebSocketState.PN_WS_NOT_STARTED); + transportWrapper.pending(); + assertTrue(webSocketImpl.getState() == WebSocket.WebSocketState.PN_WS_CONNECTING); + + // Get the key that the upgrade verifier will expect + String request = webSocketHandler.createUpgradeRequest("fakehost", "fakepath", "fakequery", 9999, "fakeprotocol", null); + String[] lines = request.split("\r\n"); + String extractedKey = null; + for (String l : lines) { + if (l.startsWith("Sec-WebSocket-Key: ")) { + extractedKey = l.substring(19).trim(); + break; + } + } + String expectedKey = null; + try { + MessageDigest messageDigest = MessageDigest.getInstance("SHA-1"); + expectedKey = Base64.encodeBase64StringLocal( + messageDigest.digest((extractedKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").getBytes())).trim(); + } catch (NoSuchAlgorithmException e) { + // can't happen since SHA-1 is a known digest + } + // Assemble a response that the upgrade verifier will accept + byte[] fakeInput = ("http/1.1 101 switching protocols\nupgrade websocket\nconnection upgrade\nsec-websocket-protocol fakeprotocol\nsec-websocket-accept " + + expectedKey).getBytes(); + + // Feed the response to the verifier, adding one byte at a time to simulate a response broken into chunks. + // This test inspired by an issue with the IBM JRE which for some reason returned the service's response in multiple pieces. + int i = 0; + ByteBuffer inputBuffer = transportWrapper.tail(); + for (i = 0; i < fakeInput.length - 1; i++) { + inputBuffer.put(fakeInput[i]); + transportWrapper.process(); + assertTrue(webSocketImpl.getState() == WebSocket.WebSocketState.PN_WS_CONNECTING); + } + // Add the last byte and the state should change. + inputBuffer.put(fakeInput[i]); + transportWrapper.process(); + assertTrue(webSocketImpl.getState() == WebSocket.WebSocketState.PN_WS_CONNECTED_FLOW); + } + @Test public void testPending_state_flow_output_not_empty() { init();