Skip to content

Commit

Permalink
Merge pull request #6183 from rossabaker/web-socket-handshake
Browse files Browse the repository at this point in the history
Move WebSocketHandshake to blaze-core
  • Loading branch information
rossabaker authored Mar 25, 2022
2 parents 8abffbc + 738b6e2 commit d2da1ed
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright 2014 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.http4s.blazecore.websocket

import cats.MonadThrow
import cats.effect.std.Random
import cats.syntax.all._
import org.http4s.crypto.Hash
import org.http4s.crypto.HashAlgorithm
import scodec.bits.ByteVector

import java.nio.charset.StandardCharsets._
import java.util.Base64

private[http4s] object WebSocketHandshake {

/** Creates a new [[ClientHandshaker]] */
def clientHandshaker[F[_]: MonadThrow](host: String, random: Random[F]): F[ClientHandshaker] =
for {
key <- random.nextBytes(16).map(Base64.getEncoder.encodeToString)
acceptKey <- genAcceptKey(key)
} yield new ClientHandshaker(host, key, acceptKey)

/** Provides the initial headers and a 16 byte Base64 encoded random key for websocket connections */
class ClientHandshaker(host: String, key: String, acceptKey: String) {

/** Initial headers to send to the server */
val initHeaders: List[(String, String)] =
("Host", host) :: ("Sec-WebSocket-Key", key) :: clientBaseHeaders

/** Check if the server response is a websocket handshake response */
def checkResponse(headers: Iterable[(String, String)]): Either[String, Unit] =
if (
!headers.exists { case (k, v) =>
k.equalsIgnoreCase("Connection") && valueContains("Upgrade", v)
}
)
Left("Bad Connection header")
else if (
!headers.exists { case (k, v) =>
k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket")
}
)
Left("Bad Upgrade header")
else
headers
.find { case (k, _) => k.equalsIgnoreCase("Sec-WebSocket-Accept") }
.map {
case (_, v) if acceptKey === v => Either.unit
case (_, v) => Left(s"Invalid key: $v")
}
.getOrElse(Left("Missing Sec-WebSocket-Accept header"))
}

/** Checks the headers received from the client and if they are valid, generates response headers */
def serverHandshake(
headers: Iterable[(String, String)]
): Either[(Int, String), collection.Seq[(String, String)]] =
if (!headers.exists { case (k, _) => k.equalsIgnoreCase("Host") })
Left((-1, "Missing Host Header"))
else if (
!headers.exists { case (k, v) =>
k.equalsIgnoreCase("Connection") && valueContains("Upgrade", v)
}
)
Left((-1, "Bad Connection header"))
else if (
!headers.exists { case (k, v) =>
k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket")
}
)
Left((-1, "Bad Upgrade header"))
else if (
!headers.exists { case (k, v) =>
k.equalsIgnoreCase("Sec-WebSocket-Version") && valueContains("13", v)
}
)
Left((-1, "Bad Websocket Version header"))
// we are past most of the 'just need them' headers
else
headers
.find { case (k, v) =>
k.equalsIgnoreCase("Sec-WebSocket-Key") && decodeLen(v) == 16
}
.map { case (_, key) =>
genAcceptKey[Either[Throwable, *]](key) match {
case Left(_) => Left((-1, "Bad Sec-WebSocket-Key header"))
case Right(acceptKey) =>
Right(
collection.Seq(
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-WebSocket-Accept", acceptKey),
)
)
}
}
.getOrElse(Left((-1, "Bad Sec-WebSocket-Key header")))

/** Check if the headers contain an 'Upgrade: websocket' header */
def isWebSocketRequest(headers: Iterable[(String, String)]): Boolean =
headers.exists { case (k, v) =>
k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket")
}

private def decodeLen(key: String): Int = Base64.getDecoder.decode(key).length

private def genAcceptKey[F[_]](str: String)(implicit F: MonadThrow[F]): F[String] = for {
data <- F.fromEither(ByteVector.encodeAscii(str))
digest <- Hash[F].digest(HashAlgorithm.SHA1, data ++ magicString)
} yield digest.toBase64

private[websocket] def valueContains(key: String, value: String): Boolean = {
val parts = value.split(",").map(_.trim)
parts.foldLeft(false)((b, s) =>
b || {
s.equalsIgnoreCase(key) ||
s.length > 1 &&
s.startsWith("\"") &&
s.endsWith("\"") &&
s.substring(1, s.length - 1).equalsIgnoreCase(key)
}
)
}

private val magicString =
ByteVector.view("258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(US_ASCII))

private val clientBaseHeaders = List(
("Connection", "Upgrade"),
("Upgrade", "websocket"),
("Sec-WebSocket-Version", "13"),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2014 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.http4s.blazecore.websocket

import cats.effect.IO
import cats.effect.std.Random
import org.http4s.Http4sSuite

class WebSocketHandshakeSuite extends Http4sSuite {

test("WebSocketHandshake should Be able to split multi value header keys") {
val totalValue = "keep-alive, Upgrade"
val values = List("upgrade", "Upgrade", "keep-alive", "Keep-alive")
assert(values.forall(v => WebSocketHandshake.valueContains(v, totalValue)))
}

test("WebSocketHandshake should do a round trip") {
for {
random <- Random.javaSecuritySecureRandom[IO]
client <- WebSocketHandshake.clientHandshaker[IO]("www.foo.com", random)
hs = client.initHeaders
valid = WebSocketHandshake.serverHandshake(hs)
_ = assert(valid.isRight)
response = client.checkResponse(valid.toOption.get)
_ = assert(response.isRight, response.swap.toOption.get)
} yield ()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import fs2.concurrent.SignallingRef
import org.http4s._
import org.http4s.blaze.pipeline.LeafBuilder
import org.http4s.blazecore.websocket.Http4sWSStage
import org.http4s.blazecore.websocket.WebSocketHandshake
import org.http4s.headers._
import org.http4s.websocket.WebSocketContext
import org.http4s.websocket.WebSocketHandshake
import org.typelevel.ci._
import org.typelevel.vault.Key

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ import java.nio.charset.StandardCharsets._
import java.util.Base64
import scala.util.Random

@deprecated(
"Retained for binary compatibility. Side-effecting. Only used by blaze-server.",
"0.23.13",
)
private[http4s] object WebSocketHandshake {

/** Creates a new [[ClientHandshaker]] */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.http4s.websocket

import org.http4s.Http4sSuite

@deprecated("Tests a deprecated feature", "0.23.13")
class WebSocketHandshakeSpec extends Http4sSuite {

test("WebSocketHandshake should Be able to split multi value header keys") {
Expand Down

0 comments on commit d2da1ed

Please sign in to comment.