Skip to content

Commit

Permalink
KTOR-6970 Darwin, Java, JS: Propagate Sec-WebSocket-Protocol header (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
osipxd authored Jan 29, 2025
1 parent 15f0921 commit e4fd251
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ internal class JsClientEngine(
headers: Headers
): WebSocket {
val protocolHeaderNames = headers.names().filter { headerName ->
headerName.equals("sec-websocket-protocol", true)
headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true)
}
val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray()
return when {
Expand Down Expand Up @@ -108,10 +108,13 @@ internal class JsClientEngine(
throw cause
}

val protocol = socket.protocol.takeIf { it.isNotEmpty() }
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty

return HttpResponseData(
HttpStatusCode.SwitchingProtocols,
requestTime,
Headers.Empty,
headers,
HttpProtocolVersion.HTTP_1_1,
session,
callContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ internal class JsClientEngine(
headers: Headers
): WebSocket {
val protocolHeaderNames = headers.names().filter { headerName ->
headerName.equals("sec-websocket-protocol", true)
headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true)
}
val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray()
return when {
Expand Down Expand Up @@ -116,10 +116,13 @@ internal class JsClientEngine(

val session = JsWebSocketSession(callContext, socket)

val protocol = socket.protocol.takeIf { it.isNotEmpty() }
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty

return HttpResponseData(
HttpStatusCode.SwitchingProtocols,
requestTime,
Headers.Empty,
headers,
HttpProtocolVersion.HTTP_1_1,
session,
callContext
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.engine.darwin

import io.ktor.client.engine.darwin.internal.*
import io.ktor.client.request.*
import io.ktor.util.collections.*
import kotlinx.cinterop.*
import kotlinx.coroutines.*
import kotlinx.cinterop.UnsafeNumber
import kotlinx.coroutines.CompletableDeferred
import platform.Foundation.*
import platform.darwin.*
import kotlin.coroutines.*
import platform.darwin.NSObject
import kotlin.collections.set
import kotlin.coroutines.CoroutineContext

private const val HTTP_REQUESTS_INITIAL_CAPACITY = 32
private const val WS_REQUESTS_INITIAL_CAPACITY = 16
Expand Down Expand Up @@ -77,7 +78,7 @@ public class KtorNSURLSessionDelegate(
didOpenWithProtocol: String?
) {
val wsSession = webSocketSessions[webSocketTask] ?: return
wsSession.didOpen()
wsSession.didOpen(didOpenWithProtocol)
}

override fun URLSession(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.engine.darwin.internal
Expand All @@ -10,13 +10,20 @@ import io.ktor.http.*
import io.ktor.util.date.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import kotlinx.cinterop.*
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.UnsafeNumber
import kotlinx.cinterop.convert
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.io.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.io.readByteArray
import platform.Foundation.*
import platform.darwin.*
import kotlin.coroutines.*
import platform.darwin.NSInteger
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

@OptIn(UnsafeNumber::class, ExperimentalForeignApi::class)
internal class DarwinWebsocketSession(
Expand Down Expand Up @@ -157,11 +164,13 @@ internal class DarwinWebsocketSession(
coroutineContext.cancel()
}

fun didOpen() {
fun didOpen(protocol: String?) {
val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty

val response = HttpResponseData(
task.getStatusCode()?.let { HttpStatusCode.fromValue(it) } ?: HttpStatusCode.SwitchingProtocols,
requestTime,
Headers.Empty,
headers,
HttpProtocolVersion.HTTP_1_1,
this,
coroutineContext
Expand All @@ -177,7 +186,7 @@ internal class DarwinWebsocketSession(

// KTOR-7363 We want to proceed with the request if we get 401 Unauthorized status code
if (task.getStatusCode() == HttpStatusCode.Unauthorized.value) {
didOpen()
didOpen(protocol = null)
socketJob.complete()
return
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.engine.java

Expand All @@ -15,14 +15,18 @@ import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.future.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.coroutines.future.asCompletableFuture
import kotlinx.coroutines.future.await
import java.net.http.*
import java.nio.*
import java.time.*
import java.nio.ByteBuffer
import java.time.Duration
import java.util.*
import java.util.concurrent.*
import kotlin.coroutines.*
import java.util.concurrent.CompletionStage
import kotlin.coroutines.CoroutineContext
import kotlin.text.String
import kotlin.text.toByteArray

Expand Down Expand Up @@ -92,9 +96,11 @@ internal class JavaHttpWebSocket(
FrameType.TEXT -> {
webSocket.sendText(String(frame.data), frame.fin).await()
}

FrameType.BINARY -> {
webSocket.sendBinary(frame.buffer, frame.fin).await()
}

FrameType.CLOSE -> {
val data = buildPacket { writeFully(frame.data) }
val code = data.readShort().toInt()
Expand All @@ -103,9 +109,11 @@ internal class JavaHttpWebSocket(
socketJob.complete()
return@launch
}

FrameType.PING -> {
webSocket.sendPing(frame.buffer).await()
}

FrameType.PONG -> {
webSocket.sendPong(frame.buffer).await()
}
Expand Down Expand Up @@ -153,11 +161,15 @@ internal class JavaHttpWebSocket(
}

var status = HttpStatusCode.SwitchingProtocols
var headers: Headers
try {
webSocket = builder.buildAsync(requestData.url.toURI(), this).await()
val protocol = webSocket.subprotocol?.takeIf { it.isNotEmpty() }
headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty
} catch (cause: WebSocketHandshakeException) {
if (cause.response.statusCode() == HttpStatusCode.Unauthorized.value) {
status = HttpStatusCode.Unauthorized
headers = headersOf(cause.response.headers().map())
} else {
throw cause
}
Expand All @@ -166,7 +178,7 @@ internal class JavaHttpWebSocket(
return HttpResponseData(
status,
requestTime,
Headers.Empty,
headers,
HttpProtocolVersion.HTTP_1_1,
this,
callContext
Expand Down Expand Up @@ -217,3 +229,11 @@ internal class JavaHttpWebSocket(
socketJob.cancel()
}
}

private fun headersOf(map: Map<String, List<String>>): Headers = object : Headers {
override val caseInsensitiveName: Boolean = true
override fun getAll(name: String): List<String>? = map[name]
override fun names(): Set<String> = map.keys
override fun entries(): Set<Map.Entry<String, List<String>>> = map.entries
override fun isEmpty(): Boolean = map.isEmpty()
}
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,25 @@ class WebSocketTest : ClientLoader() {
}
}

@Test
fun testResponseContainsSecWebsocketProtocolHeader() = clientTests(except(ENGINES_WITHOUT_WS)) {
config {
install(WebSockets)
}

test { client ->
val session = client.webSocketSession("$TEST_WEBSOCKET_SERVER/websockets/sub-protocol") {
header(HttpHeaders.SecWebSocketProtocol, "test-protocol")
}

try {
assertEquals(session.call.response.headers[HttpHeaders.SecWebSocketProtocol], "test-protocol")
} finally {
session.close()
}
}
}

@Test
fun testIncomingOverflow() = clientTests(except(ENGINES_WITHOUT_WS)) {
config {
Expand Down

0 comments on commit e4fd251

Please sign in to comment.