diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6bb951087..bec6bb120 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - scala: [3.0.2, 2.12.15, 2.13.8] + scala: [3.1.2, 2.12.15, 2.13.8] java: [temurin@8] runs-on: ${{ matrix.os }} steps: @@ -93,11 +93,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: mkdir -p target examples/target http/target core/target testkit/target project/target + run: mkdir -p blaze-client/target blaze-server/target target examples/target http/target blaze-core/target core/target testkit/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: tar cf targets.tar target examples/target http/target core/target testkit/target project/target + run: tar cf targets.tar blaze-client/target blaze-server/target target examples/target http/target blaze-core/target core/target testkit/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') @@ -150,12 +150,12 @@ jobs: ~/Library/Caches/Coursier/v1 key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }} - - name: Download target directories (3.0.2) + - name: Download target directories (3.1.2) uses: actions/download-artifact@v2 with: - name: target-${{ matrix.os }}-${{ matrix.java }}-3.0.2 + name: target-${{ matrix.os }}-${{ matrix.java }}-3.1.2 - - name: Inflate target directories (3.0.2) + - name: Inflate target directories (3.1.2) run: | tar xf targets.tar rm targets.tar diff --git a/.scalafmt.blaze.conf b/.scalafmt.blaze.conf new file mode 100644 index 000000000..dd6deff12 --- /dev/null +++ b/.scalafmt.blaze.conf @@ -0,0 +1,20 @@ +version = 3.5.2 + +style = default + +maxColumn = 100 + +// Vertical alignment is pretty, but leads to bigger diffs +align.preset = none + +danglingParentheses.preset = false + +rewrite.rules = [ + AvoidInfix + RedundantBraces + RedundantParens + AsciiSortImports + PreferCurlyFors +] + +runner.dialect = scala213source3 diff --git a/.scalafmt.conf b/.scalafmt.conf index dd6deff12..cc782f8be 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -4,17 +4,41 @@ style = default maxColumn = 100 +// Docstring wrapping breaks doctests +docstrings.wrap = false + // Vertical alignment is pretty, but leads to bigger diffs align.preset = none -danglingParentheses.preset = false +danglingParentheses.preset = true rewrite.rules = [ AvoidInfix RedundantBraces RedundantParens - AsciiSortImports PreferCurlyFors + SortModifiers +] + +rewrite.sortModifiers.order = [ + override, implicit, private, protected, final, sealed, abstract, lazy ] -runner.dialect = scala213source3 +rewrite.trailingCommas.style = multiple + +project.excludeFilters = [ + "scalafix/*", + "scalafix-internal/input/*", + "scalafix-internal/output/*" +] + +runner.dialect = scala212 + +fileOverride { + "glob:**/scala-3/**/*.scala" { + runner.dialect = scala3 + } + "glob:**/scala-2.13/**/*.scala" { + runner.dialect = scala213 + } +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BasicManager.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BasicManager.scala new file mode 100644 index 000000000..69aabaded --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BasicManager.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.syntax.all._ +import org.http4s.client.RequestKey + +private final class BasicManager[F[_], A <: Connection[F]](builder: ConnectionBuilder[F, A])( + implicit F: Sync[F] +) extends ConnectionManager[F, A] { + def borrow(requestKey: RequestKey): F[NextConnection] = + builder(requestKey).map(NextConnection(_, fresh = true)) + + override def shutdown: F[Unit] = + F.unit + + override def invalidate(connection: A): F[Unit] = + F.delay(connection.shutdown()) + + override def release(connection: A): F[Unit] = + invalidate(connection) +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClient.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClient.scala new file mode 100644 index 000000000..ccb617e14 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClient.scala @@ -0,0 +1,164 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.Applicative +import cats.effect.implicits._ +import cats.effect.kernel.Async +import cats.effect.kernel.Deferred +import cats.effect.kernel.Resource +import cats.effect.kernel.Resource.ExitCase +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.ResponseHeaderTimeoutStage +import org.http4s.client.Client +import org.http4s.client.DefaultClient +import org.http4s.client.RequestKey +import org.http4s.client.UnexpectedStatus +import org.http4s.client.middleware.Retry +import org.http4s.client.middleware.RetryPolicy + +import java.net.SocketException +import java.nio.ByteBuffer +import java.util.concurrent.TimeoutException +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ + +/** Blaze client implementation */ +object BlazeClient { + private[blaze] def makeClient[F[_], A <: BlazeConnection[F]]( + manager: ConnectionManager[F, A], + responseHeaderTimeout: Duration, + requestTimeout: Duration, + scheduler: TickWheelExecutor, + ec: ExecutionContext, + retries: Int, + dispatcher: Dispatcher[F], + )(implicit F: Async[F]): Client[F] = { + val base = new BlazeClient[F, A]( + manager, + responseHeaderTimeout, + requestTimeout, + scheduler, + ec, + dispatcher, + ) + if (retries > 0) + Retry(retryPolicy(retries))(base) + else + base + } + + private[this] val retryNow = Duration.Zero.some + private def retryPolicy[F[_]](retries: Int): RetryPolicy[F] = { (req, result, n) => + result match { + case Left(_: SocketException) if n <= retries && req.isIdempotent => retryNow + case _ => None + } + } +} + +private class BlazeClient[F[_], A <: BlazeConnection[F]]( + manager: ConnectionManager[F, A], + responseHeaderTimeout: Duration, + requestTimeout: Duration, + scheduler: TickWheelExecutor, + ec: ExecutionContext, + dispatcher: Dispatcher[F], +)(implicit F: Async[F]) + extends DefaultClient[F] { + + override def run(req: Request[F]): Resource[F, Response[F]] = { + val key = RequestKey.fromRequest(req) + for { + requestTimeoutF <- scheduleRequestTimeout(key) + preparedConnection <- prepareConnection(key) + (conn, responseHeaderTimeoutF) = preparedConnection + timeout = responseHeaderTimeoutF.race(requestTimeoutF).map(_.merge) + responseResource <- Resource.eval(runRequest(conn, req, timeout)) + response <- responseResource + } yield response + } + + override def defaultOnError(req: Request[F])(resp: Response[F])(implicit + G: Applicative[F] + ): F[Throwable] = + resp.body.compile.drain.as(UnexpectedStatus(resp.status, req.method, req.uri)) + + private def prepareConnection(key: RequestKey): Resource[F, (A, F[TimeoutException])] = for { + conn <- borrowConnection(key) + responseHeaderTimeoutF <- addResponseHeaderTimeout(conn) + } yield (conn, responseHeaderTimeoutF) + + private def borrowConnection(key: RequestKey): Resource[F, A] = + Resource.makeCase(manager.borrow(key).map(_.connection)) { + case (conn, ExitCase.Canceled) => + // Currently we can't just release in case of cancellation, because cancellation clears the Write state of Http1Connection, so it might result in isRecycle=true even if there's a half-written request. + manager.invalidate(conn) + case (conn, _) => manager.release(conn) + } + + private def addResponseHeaderTimeout(conn: A): Resource[F, F[TimeoutException]] = + responseHeaderTimeout match { + case d: FiniteDuration => + Resource.apply( + Deferred[F, Either[Throwable, TimeoutException]].flatMap(timeout => + F.delay { + val stage = new ResponseHeaderTimeoutStage[ByteBuffer](d, scheduler, ec) + conn.spliceBefore(stage) + stage.init(e => dispatcher.unsafeRunSync(timeout.complete(e).void)) + (timeout.get.rethrow, F.delay(stage.removeStage())) + } + ) + ) + case _ => resourceNeverTimeoutException + } + + private def scheduleRequestTimeout(key: RequestKey): Resource[F, F[TimeoutException]] = + requestTimeout match { + case d: FiniteDuration => + Resource.pure(F.async[TimeoutException] { cb => + F.delay( + scheduler.schedule( + () => + cb( + Right(new TimeoutException(s"Request to $key timed out after ${d.toMillis} ms")) + ), + ec, + d, + ) + ).map(c => Some(F.delay(c.cancel()))) + }) + case _ => resourceNeverTimeoutException + } + + private def runRequest( + conn: A, + req: Request[F], + timeout: F[TimeoutException], + ): F[Resource[F, Response[F]]] = + conn + .runRequest(req, timeout) + .race(timeout.flatMap(F.raiseError[Resource[F, Response[F]]](_))) + .map(_.merge) + + private val resourceNeverTimeoutException = Resource.pure[F, F[TimeoutException]](F.never) + +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientBuilder.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientBuilder.scala new file mode 100644 index 000000000..6dd2c4e79 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientBuilder.scala @@ -0,0 +1,487 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.kernel.Async +import cats.effect.kernel.Resource +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import org.http4s.blaze.channel.ChannelOptions +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.BlazeBackendBuilder +import org.http4s.blazecore.ExecutionContextConfig +import org.http4s.blazecore.tickWheelResource +import org.http4s.client.Client +import org.http4s.client.RequestKey +import org.http4s.client.defaults +import org.http4s.headers.`User-Agent` +import org.http4s.internal.BackendBuilder +import org.http4s.internal.SSLContextOption +import org.log4s.Logger +import org.log4s.getLogger + +import java.net.InetSocketAddress +import java.nio.channels.AsynchronousChannelGroup +import javax.net.ssl.SSLContext +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ + +/** Configure and obtain a BlazeClient + * @param responseHeaderTimeout duration between the submission of a request and the completion of the response header. Does not include time to read the response body. + * @param idleTimeout duration that a connection can wait without traffic being read or written before timeout + * @param requestTimeout maximum duration from the submission of a request through reading the body before a timeout. + * @param connectTimeout Duration a connection attempt times out after + * @param userAgent optional custom user agent header + * @param maxTotalConnections maximum connections the client will have at any specific time + * @param maxWaitQueueLimit maximum number requests waiting for a connection at any specific time + * @param maxConnectionsPerRequestKey Map of RequestKey to number of max connections + * @param sslContext Some custom `SSLContext`, or `None` if the default SSL context is to be lazily instantiated. + * @param checkEndpointIdentification require endpoint identification for secure requests according to RFC 2818, Section 3.1. If the certificate presented does not match the hostname of the request, the request fails with a CertificateException. This setting does not affect checking the validity of the cert via the sslContext's trust managers. + * @param maxResponseLineSize maximum length of the request line + * @param maxHeaderLength maximum length of headers + * @param maxChunkSize maximum size of chunked content chunks + * @param chunkBufferMaxSize Size of the buffer that is used when Content-Length header is not specified. + * @param parserMode lenient or strict parsing mode. The lenient mode will accept illegal chars but replaces them with � (0xFFFD) + * @param bufferSize internal buffer size of the blaze client + * @param executionContextConfig optional custom executionContext to run async computations. + * @param scheduler execution scheduler + * @param asynchronousChannelGroup custom AsynchronousChannelGroup to use other than the system default + * @param channelOptions custom socket options + * @param customDnsResolver customDnsResolver to use other than the system default + * @param retries the number of times an idempotent request that fails with a `SocketException` will be retried. This is a means to deal with connections that expired while in the pool. Retries happen immediately. The default is 2. For a more sophisticated retry strategy, see the [[org.http4s.client.middleware.Retry]] middleware. + * @param maxIdleDuration maximum time a connection can be idle and still + * be borrowed. Helps deal with connections that are closed while + * idling in the pool for an extended period. + */ +sealed abstract class BlazeClientBuilder[F[_]] private ( + val responseHeaderTimeout: Duration, + val idleTimeout: Duration, + val requestTimeout: Duration, + val connectTimeout: Duration, + val userAgent: Option[`User-Agent`], + val maxTotalConnections: Int, + val maxWaitQueueLimit: Int, + val maxConnectionsPerRequestKey: RequestKey => Int, + val sslContext: SSLContextOption, + val checkEndpointIdentification: Boolean, + val maxResponseLineSize: Int, + val maxHeaderLength: Int, + val maxChunkSize: Int, + val chunkBufferMaxSize: Int, + val parserMode: ParserMode, + val bufferSize: Int, + executionContextConfig: ExecutionContextConfig, + val scheduler: Resource[F, TickWheelExecutor], + val asynchronousChannelGroup: Option[AsynchronousChannelGroup], + val channelOptions: ChannelOptions, + val customDnsResolver: Option[RequestKey => Either[Throwable, InetSocketAddress]], + val retries: Int, + val maxIdleDuration: Duration, +)(implicit protected val F: Async[F]) + extends BlazeBackendBuilder[Client[F]] + with BackendBuilder[F, Client[F]] { + type Self = BlazeClientBuilder[F] + + protected final val logger: Logger = getLogger(this.getClass) + + @deprecated("Preserved for binary compatibility", "0.23.8") + private[BlazeClientBuilder] def this( + responseHeaderTimeout: Duration, + idleTimeout: Duration, + requestTimeout: Duration, + connectTimeout: Duration, + userAgent: Option[`User-Agent`], + maxTotalConnections: Int, + maxWaitQueueLimit: Int, + maxConnectionsPerRequestKey: RequestKey => Int, + sslContext: SSLContextOption, + checkEndpointIdentification: Boolean, + maxResponseLineSize: Int, + maxHeaderLength: Int, + maxChunkSize: Int, + chunkBufferMaxSize: Int, + parserMode: ParserMode, + bufferSize: Int, + executionContextConfig: ExecutionContextConfig, + scheduler: Resource[F, TickWheelExecutor], + asynchronousChannelGroup: Option[AsynchronousChannelGroup], + channelOptions: ChannelOptions, + customDnsResolver: Option[RequestKey => Either[Throwable, InetSocketAddress]], + F: Async[F], + ) = this( + responseHeaderTimeout = responseHeaderTimeout, + idleTimeout = idleTimeout, + requestTimeout = requestTimeout, + connectTimeout = connectTimeout, + userAgent = userAgent, + maxTotalConnections = maxTotalConnections, + maxWaitQueueLimit = maxWaitQueueLimit, + maxConnectionsPerRequestKey = maxConnectionsPerRequestKey, + sslContext = sslContext, + checkEndpointIdentification = checkEndpointIdentification, + maxResponseLineSize = maxResponseLineSize, + maxHeaderLength = maxHeaderLength, + maxChunkSize = maxChunkSize, + chunkBufferMaxSize = chunkBufferMaxSize, + parserMode = parserMode, + bufferSize = bufferSize, + executionContextConfig = executionContextConfig, + scheduler = scheduler, + asynchronousChannelGroup = asynchronousChannelGroup, + channelOptions = channelOptions, + customDnsResolver = customDnsResolver, + retries = 0, + maxIdleDuration = Duration.Inf, + )(F) + + private def copy( + responseHeaderTimeout: Duration = responseHeaderTimeout, + idleTimeout: Duration = idleTimeout, + requestTimeout: Duration = requestTimeout, + connectTimeout: Duration = connectTimeout, + userAgent: Option[`User-Agent`] = userAgent, + maxTotalConnections: Int = maxTotalConnections, + maxWaitQueueLimit: Int = maxWaitQueueLimit, + maxConnectionsPerRequestKey: RequestKey => Int = maxConnectionsPerRequestKey, + sslContext: SSLContextOption = sslContext, + checkEndpointIdentification: Boolean = checkEndpointIdentification, + maxResponseLineSize: Int = maxResponseLineSize, + maxHeaderLength: Int = maxHeaderLength, + maxChunkSize: Int = maxChunkSize, + chunkBufferMaxSize: Int = chunkBufferMaxSize, + parserMode: ParserMode = parserMode, + bufferSize: Int = bufferSize, + executionContextConfig: ExecutionContextConfig = executionContextConfig, + scheduler: Resource[F, TickWheelExecutor] = scheduler, + asynchronousChannelGroup: Option[AsynchronousChannelGroup] = asynchronousChannelGroup, + channelOptions: ChannelOptions = channelOptions, + customDnsResolver: Option[RequestKey => Either[Throwable, InetSocketAddress]] = + customDnsResolver, + retries: Int = retries, + maxIdleDuration: Duration = maxIdleDuration, + ): BlazeClientBuilder[F] = + new BlazeClientBuilder[F]( + responseHeaderTimeout = responseHeaderTimeout, + idleTimeout = idleTimeout, + requestTimeout = requestTimeout, + connectTimeout = connectTimeout, + userAgent = userAgent, + maxTotalConnections = maxTotalConnections, + maxWaitQueueLimit = maxWaitQueueLimit, + maxConnectionsPerRequestKey = maxConnectionsPerRequestKey, + sslContext = sslContext, + checkEndpointIdentification = checkEndpointIdentification, + maxResponseLineSize = maxResponseLineSize, + maxHeaderLength = maxHeaderLength, + maxChunkSize = maxChunkSize, + chunkBufferMaxSize = chunkBufferMaxSize, + parserMode = parserMode, + bufferSize = bufferSize, + executionContextConfig = executionContextConfig, + scheduler = scheduler, + asynchronousChannelGroup = asynchronousChannelGroup, + channelOptions = channelOptions, + customDnsResolver = customDnsResolver, + retries = retries, + maxIdleDuration = maxIdleDuration, + ) {} + + @deprecated( + "Do not use - always returns cats.effect.unsafe.IORuntime.global.compute." + + "There is no direct replacement - directly use Async[F].executionContext or your custom execution context", + "0.23.5", + ) + def executionContext: ExecutionContext = cats.effect.unsafe.IORuntime.global.compute + + def withResponseHeaderTimeout(responseHeaderTimeout: Duration): BlazeClientBuilder[F] = + copy(responseHeaderTimeout = responseHeaderTimeout) + + def withMaxHeaderLength(maxHeaderLength: Int): BlazeClientBuilder[F] = + copy(maxHeaderLength = maxHeaderLength) + + def withIdleTimeout(idleTimeout: Duration): BlazeClientBuilder[F] = + copy(idleTimeout = idleTimeout) + + def withRequestTimeout(requestTimeout: Duration): BlazeClientBuilder[F] = + copy(requestTimeout = requestTimeout) + + def withConnectTimeout(connectTimeout: Duration): BlazeClientBuilder[F] = + copy(connectTimeout = connectTimeout) + + def withUserAgentOption(userAgent: Option[`User-Agent`]): BlazeClientBuilder[F] = + copy(userAgent = userAgent) + def withUserAgent(userAgent: `User-Agent`): BlazeClientBuilder[F] = + withUserAgentOption(Some(userAgent)) + def withoutUserAgent: BlazeClientBuilder[F] = + withUserAgentOption(None) + + def withMaxTotalConnections(maxTotalConnections: Int): BlazeClientBuilder[F] = + copy(maxTotalConnections = maxTotalConnections) + + def withMaxWaitQueueLimit(maxWaitQueueLimit: Int): BlazeClientBuilder[F] = + copy(maxWaitQueueLimit = maxWaitQueueLimit) + + def withMaxConnectionsPerRequestKey( + maxConnectionsPerRequestKey: RequestKey => Int + ): BlazeClientBuilder[F] = + copy(maxConnectionsPerRequestKey = maxConnectionsPerRequestKey) + + /** Use the provided `SSLContext` when making secure calls */ + def withSslContext(sslContext: SSLContext): BlazeClientBuilder[F] = + copy(sslContext = SSLContextOption.Provided(sslContext)) + + /** Use an `SSLContext` obtained by `SSLContext.getDefault()` when making secure calls. + * + * Since 0.21, the creation is not deferred. + */ + def withDefaultSslContext: BlazeClientBuilder[F] = + withSslContext(SSLContext.getDefault()) + + /** Number of times to immediately retry idempotent requests that fail + * with a `SocketException`. + */ + def withRetries(retries: Int = retries): BlazeClientBuilder[F] = + copy(retries = retries) + + /** Time a connection can be idle and still be borrowed. Helps deal + * with connections that are closed while idling in the pool for an + * extended period. `Duration.Inf` means no timeout. + */ + def withMaxIdleDuration(maxIdleDuration: Duration = maxIdleDuration): BlazeClientBuilder[F] = + copy(maxIdleDuration = maxIdleDuration) + + /** Use some provided `SSLContext` when making secure calls, or disable secure calls with `None` */ + @deprecated( + message = + "Use withDefaultSslContext, withSslContext or withoutSslContext to set the SSLContext", + since = "0.22.0-M1", + ) + def withSslContextOption(sslContext: Option[SSLContext]): BlazeClientBuilder[F] = + copy(sslContext = + sslContext.fold[SSLContextOption](SSLContextOption.NoSSL)(SSLContextOption.Provided.apply) + ) + + /** Disable secure calls */ + def withoutSslContext: BlazeClientBuilder[F] = + copy(sslContext = SSLContextOption.NoSSL) + + def withCheckEndpointAuthentication(checkEndpointIdentification: Boolean): BlazeClientBuilder[F] = + copy(checkEndpointIdentification = checkEndpointIdentification) + + def withMaxResponseLineSize(maxResponseLineSize: Int): BlazeClientBuilder[F] = + copy(maxResponseLineSize = maxResponseLineSize) + + def withMaxChunkSize(maxChunkSize: Int): BlazeClientBuilder[F] = + copy(maxChunkSize = maxChunkSize) + + def withChunkBufferMaxSize(chunkBufferMaxSize: Int): BlazeClientBuilder[F] = + copy(chunkBufferMaxSize = chunkBufferMaxSize) + + def withParserMode(parserMode: ParserMode): BlazeClientBuilder[F] = + copy(parserMode = parserMode) + + def withBufferSize(bufferSize: Int): BlazeClientBuilder[F] = + copy(bufferSize = bufferSize) + + /** Configures the compute thread pool used to run async computations. + * + * This defaults to `cats.effect.Async[F].executionContext`. In + * almost all cases, it is desirable to use the default. + */ + def withExecutionContext(executionContext: ExecutionContext): BlazeClientBuilder[F] = + copy(executionContextConfig = ExecutionContextConfig.ExplicitContext(executionContext)) + + def withScheduler(scheduler: TickWheelExecutor): BlazeClientBuilder[F] = + copy(scheduler = scheduler.pure[Resource[F, *]]) + + def withAsynchronousChannelGroupOption( + asynchronousChannelGroup: Option[AsynchronousChannelGroup] + ): BlazeClientBuilder[F] = + copy(asynchronousChannelGroup = asynchronousChannelGroup) + def withAsynchronousChannelGroup( + asynchronousChannelGroup: AsynchronousChannelGroup + ): BlazeClientBuilder[F] = + withAsynchronousChannelGroupOption(Some(asynchronousChannelGroup)) + def withoutAsynchronousChannelGroup: BlazeClientBuilder[F] = + withAsynchronousChannelGroupOption(None) + + def withChannelOptions(channelOptions: ChannelOptions): BlazeClientBuilder[F] = + copy(channelOptions = channelOptions) + + def withCustomDnsResolver( + customDnsResolver: RequestKey => Either[Throwable, InetSocketAddress] + ): BlazeClientBuilder[F] = + copy(customDnsResolver = Some(customDnsResolver)) + + def resource: Resource[F, Client[F]] = + resourceWithState.map(_._1) + + /** Creates a blaze-client resource along with a [[BlazeClientState]] + * for monitoring purposes + */ + def resourceWithState: Resource[F, (Client[F], BlazeClientState[F])] = + for { + dispatcher <- Dispatcher[F] + scheduler <- scheduler + _ <- Resource.eval(verifyAllTimeoutsAccuracy(scheduler)) + _ <- Resource.eval(verifyTimeoutRelations()) + manager <- connectionManager(scheduler, dispatcher) + executionContext <- Resource.eval(executionContextConfig.getExecutionContext) + client = BlazeClient.makeClient( + manager = manager, + responseHeaderTimeout = responseHeaderTimeout, + requestTimeout = requestTimeout, + scheduler = scheduler, + ec = executionContext, + retries = retries, + dispatcher = dispatcher, + ) + + } yield (client, manager.state) + + private def verifyAllTimeoutsAccuracy(scheduler: TickWheelExecutor): F[Unit] = + for { + _ <- verifyTimeoutAccuracy(scheduler.tick, responseHeaderTimeout, "responseHeaderTimeout") + _ <- verifyTimeoutAccuracy(scheduler.tick, idleTimeout, "idleTimeout") + _ <- verifyTimeoutAccuracy(scheduler.tick, requestTimeout, "requestTimeout") + _ <- verifyTimeoutAccuracy(scheduler.tick, connectTimeout, "connectTimeout") + } yield () + + private def verifyTimeoutAccuracy( + tick: Duration, + timeout: Duration, + timeoutName: String, + ): F[Unit] = + F.delay { + val warningThreshold = 0.1 // 10% + val inaccuracy = tick / timeout + if (inaccuracy > warningThreshold) + logger.warn( + s"With current configuration, $timeoutName ($timeout) may be up to ${inaccuracy * 100}% longer than configured. " + + s"If timeout accuracy is important, consider using a scheduler with a shorter tick (currently $tick)." + ) + } + + private def verifyTimeoutRelations(): F[Unit] = + F.delay { + val advice = + s"It is recommended to configure responseHeaderTimeout < requestTimeout < idleTimeout " + + s"or disable some of them explicitly by setting them to Duration.Inf." + + if (responseHeaderTimeout.isFinite && responseHeaderTimeout >= requestTimeout) + logger.warn( + s"responseHeaderTimeout ($responseHeaderTimeout) is >= requestTimeout ($requestTimeout). $advice" + ) + + if (responseHeaderTimeout.isFinite && responseHeaderTimeout >= idleTimeout) + logger.warn( + s"responseHeaderTimeout ($responseHeaderTimeout) is >= idleTimeout ($idleTimeout). $advice" + ) + + if (requestTimeout.isFinite && requestTimeout >= idleTimeout) + logger.warn(s"requestTimeout ($requestTimeout) is >= idleTimeout ($idleTimeout). $advice") + } + + private def connectionManager(scheduler: TickWheelExecutor, dispatcher: Dispatcher[F])(implicit + F: Async[F] + ): Resource[F, ConnectionManager.Stateful[F, BlazeConnection[F]]] = { + val http1: ConnectionBuilder[F, BlazeConnection[F]] = + (requestKey: RequestKey) => + new Http1Support[F]( + sslContextOption = sslContext, + bufferSize = bufferSize, + asynchronousChannelGroup = asynchronousChannelGroup, + executionContextConfig = executionContextConfig, + scheduler = scheduler, + checkEndpointIdentification = checkEndpointIdentification, + maxResponseLineSize = maxResponseLineSize, + maxHeaderLength = maxHeaderLength, + maxChunkSize = maxChunkSize, + chunkBufferMaxSize = chunkBufferMaxSize, + parserMode = parserMode, + userAgent = userAgent, + channelOptions = channelOptions, + connectTimeout = connectTimeout, + dispatcher = dispatcher, + idleTimeout = idleTimeout, + getAddress = customDnsResolver.getOrElse(BlazeClientBuilder.getAddress(_)), + ).makeClient(requestKey) + + Resource.make( + executionContextConfig.getExecutionContext.flatMap(executionContext => + ConnectionManager.pool( + builder = http1, + maxTotal = maxTotalConnections, + maxWaitQueueLimit = maxWaitQueueLimit, + maxConnectionsPerRequestKey = maxConnectionsPerRequestKey, + responseHeaderTimeout = responseHeaderTimeout, + requestTimeout = requestTimeout, + executionContext = executionContext, + maxIdleDuration = maxIdleDuration, + ) + ) + )(_.shutdown) + } +} + +object BlazeClientBuilder { + + def apply[F[_]: Async]: BlazeClientBuilder[F] = + new BlazeClientBuilder[F]( + responseHeaderTimeout = Duration.Inf, + idleTimeout = 1.minute, + requestTimeout = defaults.RequestTimeout, + connectTimeout = defaults.ConnectTimeout, + userAgent = Some(`User-Agent`(ProductId("http4s-blaze", Some(BuildInfo.version)))), + maxTotalConnections = 10, + maxWaitQueueLimit = 256, + maxConnectionsPerRequestKey = Function.const(256), + sslContext = SSLContextOption.TryDefaultSSLContext, + checkEndpointIdentification = true, + maxResponseLineSize = 4096, + maxHeaderLength = 40960, + maxChunkSize = Int.MaxValue, + chunkBufferMaxSize = 1024 * 1024, + parserMode = ParserMode.Strict, + bufferSize = 8192, + executionContextConfig = ExecutionContextConfig.DefaultContext, + scheduler = tickWheelResource, + asynchronousChannelGroup = None, + channelOptions = ChannelOptions(Vector.empty), + customDnsResolver = None, + retries = 2, + maxIdleDuration = Duration.Inf, + ) {} + + @deprecated( + "Most users should use the default execution context provided. " + + "If you have a specific reason to use a custom one, use `.withExecutionContext`", + "0.23.5", + ) + def apply[F[_]: Async](executionContext: ExecutionContext): BlazeClientBuilder[F] = + BlazeClientBuilder[F].withExecutionContext(executionContext) + + def getAddress(requestKey: RequestKey): Either[Throwable, InetSocketAddress] = + requestKey match { + case RequestKey(s, auth) => + val port = auth.port.getOrElse(if (s == Uri.Scheme.https) 443 else 80) + val host = auth.host.value + Either.catchNonFatal(new InetSocketAddress(host, port)) + } +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientState.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientState.scala new file mode 100644 index 000000000..3ded8bed2 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeClientState.scala @@ -0,0 +1,28 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.client + +import org.http4s.client.RequestKey + +import scala.collection.immutable + +trait BlazeClientState[F[_]] { + def isClosed: F[Boolean] + def allocated: F[immutable.Map[RequestKey, Int]] + def idleQueueDepth: F[immutable.Map[RequestKey, Int]] + def waitQueueDepth: F[Int] +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeConnection.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeConnection.scala new file mode 100644 index 000000000..513a2746d --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeConnection.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.Resource +import org.http4s.blaze.pipeline.TailStage + +import java.nio.ByteBuffer +import java.util.concurrent.TimeoutException + +private trait BlazeConnection[F[_]] extends TailStage[ByteBuffer] with Connection[F] { + def runRequest(req: Request[F], cancellation: F[TimeoutException]): F[Resource[F, Response[F]]] +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeHttp1ClientParser.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeHttp1ClientParser.scala new file mode 100644 index 000000000..8c6e818f0 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/BlazeHttp1ClientParser.scala @@ -0,0 +1,92 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.client + +import cats.syntax.all._ +import org.http4s._ +import org.http4s.blaze.http.parser.Http1ClientParser +import org.typelevel.ci.CIString + +import java.nio.ByteBuffer +import scala.collection.mutable.ListBuffer + +private[blaze] final class BlazeHttp1ClientParser( + maxResponseLineSize: Int, + maxHeaderLength: Int, + maxChunkSize: Int, + parserMode: ParserMode, +) extends Http1ClientParser( + maxResponseLineSize, + maxHeaderLength, + 2 * 1024, + maxChunkSize, + parserMode == ParserMode.Lenient, + ) { + private val headers = new ListBuffer[Header.Raw] + private var status: Status = _ + private var httpVersion: HttpVersion = _ + + override def reset(): Unit = { + headers.clear() + status = null + httpVersion = null + super.reset() + } + + def getHttpVersion(): HttpVersion = + if (httpVersion == null) HttpVersion.`HTTP/1.0` // TODO Questionable default + else httpVersion + + def doParseContent(buffer: ByteBuffer): Option[ByteBuffer] = Option(parseContent(buffer)) + + def getHeaders(): Headers = + if (headers.isEmpty) Headers.empty + else { + val hs = Headers(headers.result()) + headers.clear() // clear so we can accumulate trailing headers + hs + } + + def getStatus(): Status = + if (status == null) Status.InternalServerError + else status + + def finishedResponseLine(buffer: ByteBuffer): Boolean = + responseLineComplete() || parseResponseLine(buffer) + + def finishedHeaders(buffer: ByteBuffer): Boolean = + headersComplete() || parseHeaders(buffer) + + override protected def submitResponseLine( + code: Int, + reason: String, + scheme: String, + majorversion: Int, + minorversion: Int, + ): Unit = { + status = Status.fromInt(code).valueOr(throw _) + httpVersion = + if (majorversion == 1 && minorversion == 1) HttpVersion.`HTTP/1.1` + else if (majorversion == 1 && minorversion == 0) HttpVersion.`HTTP/1.0` + else HttpVersion.fromVersion(majorversion, minorversion).getOrElse(HttpVersion.`HTTP/1.0`) + } + + override protected def headerComplete(name: String, value: String): Boolean = { + headers += Header.Raw(CIString(name), value) + false + } +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/Connection.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/Connection.scala new file mode 100644 index 000000000..2efa4dc41 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/Connection.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import org.http4s.client.RequestKey + +private[client] trait Connection[F[_]] { + + /** Determine if the connection is closed and resources have been freed */ + def isClosed: Boolean + + /** Determine if the connection is in a state that it can be recycled for another request. */ + def isRecyclable: F[Boolean] + + /** Close down the connection, freeing resources and potentially aborting a [[Response]] */ + def shutdown(): Unit + + /** The key for requests we are able to serve */ + def requestKey: RequestKey +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionBuilder.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionBuilder.scala new file mode 100644 index 000000000..0b6c0ad89 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionBuilder.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.client + +import org.http4s.client.RequestKey + +private[client] trait ConnectionBuilder[F[_], A <: Connection[F]] { + def apply(key: RequestKey): F[A] +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionManager.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionManager.scala new file mode 100644 index 000000000..2f5ada5e1 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/ConnectionManager.scala @@ -0,0 +1,125 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.effect.std.Semaphore +import cats.syntax.all._ +import org.http4s.client.RequestKey + +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.Duration + +/** Type that is responsible for the client lifecycle + * + * The [[ConnectionManager]] is a general wrapper around a [[ConnectionBuilder]] + * that can pool resources in order to conserve resources such as socket connections, + * CPU time, SSL handshakes, etc. Because it can contain significant resources it + * must have a mechanism to free resources associated with it. + */ +private trait ConnectionManager[F[_], A <: Connection[F]] { + + /** Bundle of the connection and whether its new or not */ + // Sealed, rather than final, because SI-4440. + sealed case class NextConnection(connection: A, fresh: Boolean) + + /** Shutdown this client, closing any open connections and freeing resources */ + def shutdown: F[Unit] + + /** Get a connection for the provided request key. */ + def borrow(requestKey: RequestKey): F[NextConnection] + + /** Release a connection. The connection manager may choose to keep the connection for + * subsequent calls to [[borrow]], or dispose of the connection. + */ + def release(connection: A): F[Unit] + + /** Invalidate a connection, ensuring that its resources are freed. The connection + * manager may not return this connection on another borrow. + */ + def invalidate(connection: A): F[Unit] +} + +private object ConnectionManager { + trait Stateful[F[_], A <: Connection[F]] extends ConnectionManager[F, A] { + def state: BlazeClientState[F] + } + + /** Create a [[ConnectionManager]] that creates new connections on each request + * + * @param builder generator of new connections + */ + def basic[F[_]: Sync, A <: Connection[F]]( + builder: ConnectionBuilder[F, A] + ): ConnectionManager[F, A] = + new BasicManager[F, A](builder) + + /** Create a [[ConnectionManager]] that will attempt to recycle connections + * + * @param builder generator of new connections + * @param maxTotal max total connections + * @param maxWaitQueueLimit maximum number requests waiting for a connection at any specific time + * @param maxConnectionsPerRequestKey Map of RequestKey to number of max connections + * @param executionContext `ExecutionContext` where async operations will execute + */ + def pool[F[_]: Async, A <: Connection[F]]( + builder: ConnectionBuilder[F, A], + maxTotal: Int, + maxWaitQueueLimit: Int, + maxConnectionsPerRequestKey: RequestKey => Int, + responseHeaderTimeout: Duration, + requestTimeout: Duration, + executionContext: ExecutionContext, + maxIdleDuration: Duration, + ): F[ConnectionManager.Stateful[F, A]] = + Semaphore(1).map { semaphore => + new PoolManager[F, A]( + builder, + maxTotal, + maxWaitQueueLimit, + maxConnectionsPerRequestKey, + responseHeaderTimeout, + requestTimeout, + semaphore, + executionContext, + maxIdleDuration, + ) + } + + @deprecated("Preserved for binary compatibility", "0.23.8") + def pool[F[_]: Async, A <: Connection[F]]( + builder: ConnectionBuilder[F, A], + maxTotal: Int, + maxWaitQueueLimit: Int, + maxConnectionsPerRequestKey: RequestKey => Int, + responseHeaderTimeout: Duration, + requestTimeout: Duration, + executionContext: ExecutionContext, + ): F[ConnectionManager.Stateful[F, A]] = + pool( + builder, + maxTotal, + maxWaitQueueLimit, + maxConnectionsPerRequestKey, + responseHeaderTimeout, + requestTimeout, + executionContext, + Duration.Inf, + ) +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Connection.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Connection.scala new file mode 100644 index 000000000..3ea0634e5 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Connection.scala @@ -0,0 +1,506 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.implicits._ +import cats.effect.kernel.Async +import cats.effect.kernel.Deferred +import cats.effect.kernel.Outcome +import cats.effect.kernel.Resource +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2._ +import org.http4s.Uri.Authority +import org.http4s.Uri.RegName +import org.http4s.blaze.pipeline.Command.EOF +import org.http4s.blazecore.Http1Stage +import org.http4s.blazecore.IdleTimeoutStage +import org.http4s.blazecore.util.Http1Writer +import org.http4s.client.RequestKey +import org.http4s.headers.Host +import org.http4s.headers.`Content-Length` +import org.http4s.headers.`User-Agent` +import org.http4s.headers.{Connection => HConnection} +import org.http4s.internal.CharPredicate +import org.http4s.util.StringWriter +import org.http4s.util.Writer +import org.typelevel.vault._ + +import java.net.SocketException +import java.nio.ByteBuffer +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.tailrec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.util.Failure +import scala.util.Success + +private final class Http1Connection[F[_]]( + val requestKey: RequestKey, + override protected val executionContext: ExecutionContext, + maxResponseLineSize: Int, + maxHeaderLength: Int, + maxChunkSize: Int, + override val chunkBufferMaxSize: Int, + parserMode: ParserMode, + userAgent: Option[`User-Agent`], + idleTimeoutStage: Option[IdleTimeoutStage[ByteBuffer]], + override val dispatcher: Dispatcher[F], +)(implicit protected val F: Async[F]) + extends Http1Stage[F] + with BlazeConnection[F] { + import Http1Connection._ + import Resource.ExitCase + + override def name: String = getClass.getName + private val parser = + new BlazeHttp1ClientParser(maxResponseLineSize, maxHeaderLength, maxChunkSize, parserMode) + + private val stageState = new AtomicReference[State](ReadIdle(None)) + private val closed = Deferred.unsafe[F, Unit] + + override def isClosed: Boolean = + stageState.get match { + case Error(_) => true + case _ => false + } + + override def isRecyclable: F[Boolean] = + F.delay(stageState.get match { + case ReadIdle(_) => true + case _ => false + }) + + override def shutdown(): Unit = stageShutdown() + + override def stageShutdown(): Unit = shutdownWithError(EOF) + + override protected def fatalError(t: Throwable, msg: String): Unit = { + val realErr = t match { + case _: TimeoutException => EOF + case EOF => EOF + case t => + logger.error(t)(s"Fatal Error: $msg") + t + } + shutdownWithError(realErr) + } + + @tailrec + private def shutdownWithError(t: Throwable): Unit = + stageState.get match { + // If we have a real error, lets put it here. + case st @ Error(EOF) if t != EOF => + if (!stageState.compareAndSet(st, Error(t))) shutdownWithError(t) + else { + closePipeline(Some(t)) + } + + case Error(_) => // NOOP: already shut down + case x => + if (!stageState.compareAndSet(x, Error(t))) shutdownWithError(t) + else { + val cmd = t match { + case EOF => None + case _ => Some(t) + } + closePipeline(cmd) + super.stageShutdown() + dispatcher.unsafeRunAndForget(closed.complete(())) + } + } + + @tailrec + def resetRead(): Unit = { + val state = stageState.get() + val nextState = state match { + case ReadActive => + // idleTimeout is activated when entering ReadWrite state, remains active throughout Read and Write and is deactivated when entering the Idle state + idleTimeoutStage.foreach(_.cancelTimeout()) + Some(ReadIdle(Some(startIdleRead()))) + case _ => None + } + + nextState match { + case Some(n) => if (stageState.compareAndSet(state, n)) parser.reset() else resetRead() + case None => () + } + } + + // #4798 We read from the channel while the connection is idle, in order to receive an EOF when the connection gets closed. + private def startIdleRead(): Future[ByteBuffer] = { + val f = channelRead() + f.onComplete { + case Failure(t) => shutdownWithError(t) + case _ => + }(executionContext) + f + } + + def runRequest(req: Request[F], cancellation: F[TimeoutException]): F[Resource[F, Response[F]]] = + F.defer[Resource[F, Response[F]]] { + stageState.get match { + case i @ ReadIdle(idleRead) => + if (stageState.compareAndSet(i, ReadActive)) { + logger.debug(s"Connection was idle. Running.") + executeRequest(req, cancellation, idleRead) + } else { + logger.debug(s"Connection changed state since checking it was idle. Looping.") + runRequest(req, cancellation) + } + case ReadActive => + logger.error(s"Tried to run a request already in running state.") + F.raiseError(InProgressException) + case Error(e) => + logger.debug(s"Tried to run a request in closed/error state: $e") + F.raiseError(e) + } + } + + override protected def doParseContent(buffer: ByteBuffer): Option[ByteBuffer] = + parser.doParseContent(buffer) + + override protected def contentComplete(): Boolean = parser.contentComplete() + + private def executeRequest( + req: Request[F], + cancellation: F[TimeoutException], + idleRead: Option[Future[ByteBuffer]], + ): F[Resource[F, Response[F]]] = { + logger.debug(s"Beginning request: ${req.method} ${req.uri}") + validateRequest(req) match { + case Left(e) => + F.raiseError(e) + case Right(req) => + F.defer[Resource[F, Response[F]]] { + val initWriterSize: Int = 512 + val rr: StringWriter = new StringWriter(initWriterSize) + val isServer: Boolean = false + + // Side Effecting Code + encodeRequestLine(req, rr) + Http1Stage.encodeHeaders(req.headers.headers, rr, isServer) + if (userAgent.nonEmpty && !req.headers.contains[`User-Agent`]) + rr << userAgent.get << "\r\n" + + val mustClose: Boolean = req.headers.get[HConnection] match { + case Some(conn) => checkCloseConnection(conn, rr) + case None => getHttpMinor(req) == 0 + } + + val writeRequest: F[Boolean] = getChunkEncoder(req, mustClose, rr) + .write(rr, req.body) + .onError { + case EOF => F.delay(shutdownWithError(EOF)) + case t => + F.delay(logger.error(t)("Error rendering request")) >> F.delay(shutdownWithError(t)) + } + + val idleTimeoutF: F[TimeoutException] = idleTimeoutStage match { + case Some(stage) => F.async_[TimeoutException](stage.setTimeout) + case None => F.never[TimeoutException] + } + + idleTimeoutF.start.flatMap { timeoutFiber => + // the request timeout, the response header timeout, and the idle timeout + val mergedTimeouts = cancellation.race(timeoutFiber.joinWithNever).map(_.merge) + F.bracketCase( + writeRequest.start + )(writeFiber => + receiveResponse( + mustClose, + doesntHaveBody = req.method == Method.HEAD, + mergedTimeouts.map(Left(_)), + idleRead, + ).map(response => + // We need to finish writing before we attempt to recycle the connection. We consider three scenarios: + // - The write already finished before we got the response. This is the most common scenario. `join` completes immediately. + // - The whole request was already transmitted and we received the response from the server, but we did not yet notice that the write is complete. This is sort of a race, it happens frequently enough when load testing. We need to wait just a moment for the `join` to finish. + // - The server decided to reject our request before we finished sending it. The server responded (typically with an error) and closed the connection. We shouldn't wait for the `writeFiber`. This connection needs to be disposed. + Resource.make(F.pure(response))(_ => + writeFiber.join.attempt.race(closed.get >> writeFiber.cancel.start).void + ) + ) + ) { + case (_, Outcome.Succeeded(_)) => F.unit + case (_, Outcome.Canceled()) => F.delay(shutdown()) + case (_, Outcome.Errored(e)) => F.delay(shutdownWithError(e)) + }.race(mergedTimeouts) + .flatMap { + case Left(r) => F.pure(r) + case Right(t) => F.raiseError(t) + } + } + }.adaptError { case EOF => + new SocketException(s"HTTP connection closed: ${requestKey}") + } + } + } + + private def receiveResponse( + closeOnFinish: Boolean, + doesntHaveBody: Boolean, + idleTimeoutS: F[Either[Throwable, Unit]], + idleRead: Option[Future[ByteBuffer]], + ): F[Response[F]] = + F.async_[Response[F]] { cb => + val read = idleRead.getOrElse(channelRead()) + handleRead(read, cb, closeOnFinish, doesntHaveBody, "Initial Read", idleTimeoutS) + } + + // this method will get some data, and try to continue parsing using the implicit ec + private def readAndParsePrelude( + cb: Callback[Response[F]], + closeOnFinish: Boolean, + doesntHaveBody: Boolean, + phase: String, + idleTimeoutS: F[Either[Throwable, Unit]], + ): Unit = + handleRead(channelRead(), cb, closeOnFinish, doesntHaveBody, phase, idleTimeoutS) + + private def handleRead( + read: Future[ByteBuffer], + cb: Callback[Response[F]], + closeOnFinish: Boolean, + doesntHaveBody: Boolean, + phase: String, + idleTimeoutS: F[Either[Throwable, Unit]], + ): Unit = + read.onComplete { + case Success(buff) => parsePrelude(buff, closeOnFinish, doesntHaveBody, cb, idleTimeoutS) + case Failure(EOF) => + stageState.get match { + case Error(e) => cb(Left(e)) + case _ => + shutdown() + cb(Left(EOF)) + } + + case Failure(t) => + fatalError(t, s"Error during phase: $phase") + cb(Left(t)) + }(executionContext) + + private def parsePrelude( + buffer: ByteBuffer, + closeOnFinish: Boolean, + doesntHaveBody: Boolean, + cb: Callback[Response[F]], + idleTimeoutS: F[Either[Throwable, Unit]], + ): Unit = + try + if (!parser.finishedResponseLine(buffer)) + readAndParsePrelude( + cb, + closeOnFinish, + doesntHaveBody, + "Response Line Parsing", + idleTimeoutS, + ) + else if (!parser.finishedHeaders(buffer)) + readAndParsePrelude(cb, closeOnFinish, doesntHaveBody, "Header Parsing", idleTimeoutS) + else + parsePreludeFinished(buffer, closeOnFinish, doesntHaveBody, cb, idleTimeoutS) + catch { + case t: Throwable => + logger.error(t)("Error during client request decode loop") + cb(Left(t)) + } + + // it's called when headers and response line parsing are finished + private def parsePreludeFinished( + buffer: ByteBuffer, + closeOnFinish: Boolean, + doesntHaveBody: Boolean, + cb: Callback[Response[F]], + idleTimeoutS: F[Either[Throwable, Unit]], + ): Unit = { + // Get headers and determine if we need to close + val headers: Headers = parser.getHeaders() + val status: Status = parser.getStatus() + val httpVersion: HttpVersion = parser.getHttpVersion() + + val (attributes, body): (Vault, EntityBody[F]) = if (doesntHaveBody) { + // responses to HEAD requests do not have a body + cleanUpAfterReceivingResponse(closeOnFinish, headers) + (Vault.empty, EmptyBody) + } else { + // We are to the point of parsing the body and then cleaning up + val (rawBody, _): (EntityBody[F], () => Future[ByteBuffer]) = + collectBodyFromParser(buffer, onEofWhileReadingBody _) + + // to collect the trailers we need a cleanup helper and an effect in the attribute map + val (trailerCleanup, attributes): (() => Unit, Vault) = + if (parser.getHttpVersion().minor == 1 && parser.isChunked()) { + val trailers = new AtomicReference(Headers.empty) + + val attrs = Vault.empty.insert[F[Headers]]( + Message.Keys.TrailerHeaders[F], + F.defer { + if (parser.contentComplete()) F.pure(trailers.get()) + else + F.raiseError( + new IllegalStateException( + "Attempted to collect trailers before the body was complete." + ) + ) + }, + ) + + (() => trailers.set(parser.getHeaders()), attrs) + } else + (() => (), Vault.empty) + + if (parser.contentComplete()) { + trailerCleanup() + cleanUpAfterReceivingResponse(closeOnFinish, headers) + attributes -> rawBody + } else + attributes -> rawBody.onFinalizeCaseWeak { + case ExitCase.Succeeded => + F.delay { trailerCleanup(); cleanUpAfterReceivingResponse(closeOnFinish, headers); } + .evalOn(executionContext) + case ExitCase.Errored(_) | ExitCase.Canceled => + F.delay { + trailerCleanup(); cleanUpAfterReceivingResponse(closeOnFinish, headers); + stageShutdown() + }.evalOn(executionContext) + } + } + + cb( + Right( + Response[F]( + status = status, + httpVersion = httpVersion, + headers = headers, + body = body.interruptWhen(idleTimeoutS), + attributes = attributes, + ) + ) + ) + } + + // It's called when an EOF is received while reading response body. + // It's responsible for deciding if the EOF should be considered an error or an indication of the end of the body. + private def onEofWhileReadingBody(): Either[Throwable, Option[Chunk[Byte]]] = + stageState.get match { // if we don't have a length, EOF signals the end of the body. + case Error(e) if e != EOF => Either.left(e) + case _ => + if (parser.definedContentLength() || parser.isChunked()) + Either.left(InvalidBodyException("Received premature EOF.")) + else Either.right(None) + } + + private def cleanUpAfterReceivingResponse(closeOnFinish: Boolean, headers: Headers): Unit = + if (closeOnFinish || headers.get[HConnection].exists(_.hasClose)) { + logger.debug("Message body complete. Shutting down.") + stageShutdown() + } else { + logger.debug(s"Resetting $name after completing request.") + resetRead() + } + + // /////////////////////// Private helpers ///////////////////////// + + /** Validates the request, attempting to fix it if possible, + * returning an Exception if invalid, None otherwise + */ + @tailrec private def validateRequest(req: Request[F]): Either[Exception, Request[F]] = { + val minor: Int = getHttpMinor(req) + + minor match { + // If we are HTTP/1.0, make sure HTTP/1.0 has no body or a Content-Length header + case 0 if !req.headers.contains[`Content-Length`] => + logger.warn(s"Request $req is HTTP/1.0 but lacks a length header. Transforming to HTTP/1.1") + validateRequest(req.withHttpVersion(HttpVersion.`HTTP/1.1`)) + + case 1 if req.uri.host.isEmpty => // this is unlikely if not impossible + // Ensure we have a host header for HTTP/1.1 + req.headers.get[Host] match { + case Some(host) => + val newAuth = req.uri.authority match { + case Some(auth) => auth.copy(host = RegName(host.host), port = host.port) + case None => Authority(host = RegName(host.host), port = host.port) + } + validateRequest(req.withUri(req.uri.copy(authority = Some(newAuth)))) + + case None if req.headers.contains[`Content-Length`] => + // translate to HTTP/1.0 + validateRequest(req.withHttpVersion(HttpVersion.`HTTP/1.0`)) + + case None => + Left(new IllegalArgumentException("Host header required for HTTP/1.1 request")) + } + + case _ if req.uri.path == Uri.Path.empty => + Right(req.withUri(req.uri.copy(path = Uri.Path.Root))) + + case _ if req.uri.path.renderString.exists(ForbiddenUriCharacters) => + Left(new IllegalArgumentException(s"Invalid URI path: ${req.uri.path}")) + + case _ => + Right(req) // All appears to be well + } + } + + private def getChunkEncoder( + req: Request[F], + closeHeader: Boolean, + rr: StringWriter, + ): Http1Writer[F] = + getEncoder(req, rr, getHttpMinor(req), closeHeader) +} + +private object Http1Connection { + case object InProgressException extends Exception("Stage has request in progress") + + // ADT representing the state that the ClientStage can be in + private sealed trait State + private final case class ReadIdle(idleRead: Option[Future[ByteBuffer]]) extends State + private case object ReadActive extends State + private final case class Error(exc: Throwable) extends State + + private def getHttpMinor[F[_]](req: Request[F]): Int = req.httpVersion.minor + + private def encodeRequestLine[F[_]](req: Request[F], writer: Writer): writer.type = { + val uri = req.uri + writer << req.method << ' ' << uri.toOriginForm << ' ' << req.httpVersion << "\r\n" + if ( + getHttpMinor(req) == 1 && + !req.headers.contains[Host] + ) { // need to add the host header for HTTP/1.1 + uri.host match { + case Some(host) => + writer << "Host: " << host.value + if (uri.port.isDefined) writer << ':' << uri.port.get + writer << "\r\n" + + case None => + // TODO: do we want to do this by exception? + throw new IllegalArgumentException("Request URI must have a host.") + } + writer + } else writer + } + + private val ForbiddenUriCharacters = CharPredicate(0x0.toChar, '\r', '\n') + +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Support.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Support.scala new file mode 100644 index 000000000..f88212ac7 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/Http1Support.scala @@ -0,0 +1,186 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.kernel.Async +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import org.http4s.blaze.channel.ChannelOptions +import org.http4s.blaze.channel.nio2.ClientChannelFactory +import org.http4s.blaze.pipeline.Command +import org.http4s.blaze.pipeline.HeadStage +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.pipeline.stages.SSLStage +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.ExecutionContextConfig +import org.http4s.blazecore.IdleTimeoutStage +import org.http4s.blazecore.util.fromFutureNoShift +import org.http4s.client.ConnectionFailure +import org.http4s.client.RequestKey +import org.http4s.headers.`User-Agent` +import org.http4s.internal.SSLContextOption + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.AsynchronousChannelGroup +import javax.net.ssl.SSLContext +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.concurrent.duration.FiniteDuration +import scala.util.Failure +import scala.util.Success + +/** Provides basic HTTP1 pipeline building + */ +private final class Http1Support[F[_]]( + sslContextOption: SSLContextOption, + bufferSize: Int, + asynchronousChannelGroup: Option[AsynchronousChannelGroup], + executionContextConfig: ExecutionContextConfig, + scheduler: TickWheelExecutor, + checkEndpointIdentification: Boolean, + maxResponseLineSize: Int, + maxHeaderLength: Int, + maxChunkSize: Int, + chunkBufferMaxSize: Int, + parserMode: ParserMode, + userAgent: Option[`User-Agent`], + channelOptions: ChannelOptions, + connectTimeout: Duration, + idleTimeout: Duration, + dispatcher: Dispatcher[F], + getAddress: RequestKey => Either[Throwable, InetSocketAddress], +)(implicit F: Async[F]) { + private val connectionManager = new ClientChannelFactory( + bufferSize, + asynchronousChannelGroup, + channelOptions, + scheduler, + connectTimeout, + ) + + def makeClient(requestKey: RequestKey): F[BlazeConnection[F]] = + getAddress(requestKey) match { + case Right(a) => + fromFutureNoShift( + executionContextConfig.getExecutionContext.flatMap(ec => + F.delay(buildPipeline(requestKey, a, ec)) + ) + ) + case Left(t) => F.raiseError(t) + } + + private def buildPipeline( + requestKey: RequestKey, + addr: InetSocketAddress, + executionContext: ExecutionContext, + ): Future[BlazeConnection[F]] = + connectionManager + .connect(addr) + .transformWith { + case Success(head) => + buildStages(requestKey, head, executionContext) match { + case Right(connection) => + Future.successful { + head.inboundCommand(Command.Connected) + connection + } + case Left(e) => + Future.failed(new ConnectionFailure(requestKey, addr, e)) + } + case Failure(e) => Future.failed(new ConnectionFailure(requestKey, addr, e)) + }(executionContext) + + private def buildStages( + requestKey: RequestKey, + head: HeadStage[ByteBuffer], + executionContext: ExecutionContext, + ): Either[IllegalStateException, BlazeConnection[F]] = { + + val idleTimeoutStage: Option[IdleTimeoutStage[ByteBuffer]] = makeIdleTimeoutStage( + executionContext + ) + val ssl: Either[IllegalStateException, Option[SSLStage]] = makeSslStage(requestKey) + + val connection = new Http1Connection( + requestKey = requestKey, + executionContext = executionContext, + maxResponseLineSize = maxResponseLineSize, + maxHeaderLength = maxHeaderLength, + maxChunkSize = maxChunkSize, + chunkBufferMaxSize = chunkBufferMaxSize, + parserMode = parserMode, + userAgent = userAgent, + idleTimeoutStage = idleTimeoutStage, + dispatcher = dispatcher, + ) + + ssl.map { sslStage => + val builder1 = LeafBuilder(connection) + val builder2 = idleTimeoutStage.fold(builder1)(builder1.prepend(_)) + val builder3 = sslStage.fold(builder2)(builder2.prepend(_)) + builder3.base(head) + + connection + } + } + + private def makeIdleTimeoutStage( + executionContext: ExecutionContext + ): Option[IdleTimeoutStage[ByteBuffer]] = + idleTimeout match { + case d: FiniteDuration => + Some(new IdleTimeoutStage[ByteBuffer](d, scheduler, executionContext)) + case _ => None + } + + private def makeSslStage( + requestKey: RequestKey + ): Either[IllegalStateException, Option[SSLStage]] = + requestKey match { + case RequestKey(Uri.Scheme.https, auth) => + val maybeSSLContext: Option[SSLContext] = + SSLContextOption.toMaybeSSLContext(sslContextOption) + + maybeSSLContext match { + case Some(sslContext) => + val eng = sslContext.createSSLEngine(auth.host.value, auth.port.getOrElse(443)) + eng.setUseClientMode(true) + + if (checkEndpointIdentification) { + val sslParams = eng.getSSLParameters + sslParams.setEndpointIdentificationAlgorithm("HTTPS") + eng.setSSLParameters(sslParams) + } + + Right(Some(new SSLStage(eng))) + + case None => + Left( + new IllegalStateException( + "No SSLContext configured for this client. Try `withSslContext` on the `BlazeClientBuilder`, or do not make https calls." + ) + ) + } + + case _ => + Right(None) + } +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/ParserMode.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/ParserMode.scala new file mode 100644 index 000000000..34426c22b --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/ParserMode.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze +package client + +sealed abstract class ParserMode extends Product with Serializable + +object ParserMode { + case object Strict extends ParserMode + case object Lenient extends ParserMode +} diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/PoolManager.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/PoolManager.scala new file mode 100644 index 000000000..d2615abac --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/PoolManager.scala @@ -0,0 +1,453 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.effect.std.Semaphore +import cats.effect.syntax.all._ +import cats.syntax.all._ +import org.http4s.client.RequestKey +import org.http4s.internal.CollectionCompat +import org.log4s.getLogger + +import java.time.Instant +import scala.collection.mutable +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ +import scala.util.Random + +final case class WaitQueueFullFailure() extends RuntimeException { + override def getMessage: String = "Wait queue is full" +} + +/** @param maxIdleDuration the maximum time a connection can be idle + * and still be borrowed + */ +private final class PoolManager[F[_], A <: Connection[F]]( + builder: ConnectionBuilder[F, A], + maxTotal: Int, + maxWaitQueueLimit: Int, + maxConnectionsPerRequestKey: RequestKey => Int, + responseHeaderTimeout: Duration, + requestTimeout: Duration, + semaphore: Semaphore[F], + implicit private val executionContext: ExecutionContext, + maxIdleDuration: Duration, +)(implicit F: Async[F]) + extends ConnectionManager.Stateful[F, A] { self => + + @deprecated("Preserved for binary compatibility", "0.23.8") + private[PoolManager] def this( + builder: ConnectionBuilder[F, A], + maxTotal: Int, + maxWaitQueueLimit: Int, + maxConnectionsPerRequestKey: RequestKey => Int, + responseHeaderTimeout: Duration, + requestTimeout: Duration, + semaphore: Semaphore[F], + executionContext: ExecutionContext, + F: Async[F], + ) = this( + builder, + maxTotal, + maxWaitQueueLimit, + maxConnectionsPerRequestKey, + responseHeaderTimeout, + requestTimeout, + semaphore, + executionContext, + Duration.Inf, + )(F) + + private sealed case class PooledConnection(conn: A, borrowDeadline: Option[Deadline]) + + private sealed case class Waiting( + key: RequestKey, + callback: Callback[NextConnection], + at: Instant, + ) + + private[this] val logger = getLogger + + private var isClosed = false + private var curTotal = 0 + private val allocated = mutable.Map.empty[RequestKey, Int] + private val idleQueues = mutable.Map.empty[RequestKey, mutable.Queue[PooledConnection]] + private var waitQueue = mutable.Queue.empty[Waiting] + + private def stats = + s"curAllocated=$curTotal idleQueues.size=${idleQueues.size} waitQueue.size=${waitQueue.size} maxWaitQueueLimit=$maxWaitQueueLimit closed=${isClosed}" + + private def getConnectionFromQueue(key: RequestKey): F[Option[PooledConnection]] = + F.delay { + idleQueues.get(key).flatMap { q => + if (q.nonEmpty) { + val pooled = q.dequeue() + if (q.isEmpty) idleQueues.remove(key) + Some(pooled) + } else None + } + } + + private def incrConnection(key: RequestKey): F[Unit] = + F.delay { + curTotal += 1 + allocated.update(key, allocated.getOrElse(key, 0) + 1) + } + + private def decrConnection(key: RequestKey): F[Unit] = + F.delay { + curTotal -= 1 + val numConnections = allocated.getOrElse(key, 0) + // If there are no more connections drop the key + if (numConnections == 1) { + allocated.remove(key) + idleQueues.remove(key) + () + } else + allocated.update(key, numConnections - 1) + } + + private def numConnectionsCheckHolds(key: RequestKey): Boolean = + curTotal < maxTotal && allocated.getOrElse(key, 0) < maxConnectionsPerRequestKey(key) + + private def isRequestExpired(t: Instant): Boolean = { + val elapsed = Instant.now().toEpochMilli - t.toEpochMilli + (requestTimeout.isFinite && elapsed >= requestTimeout.toMillis) || (responseHeaderTimeout.isFinite && elapsed >= responseHeaderTimeout.toMillis) + } + + /** This method is the core method for creating a connection which increments allocated synchronously + * then builds the connection with the given callback and completes the callback. + * + * If we can create a connection then it initially increments the allocated value within a region + * that is called synchronously by the calling method. Then it proceeds to attempt to create the connection + * and feed it the callback. If we cannot create a connection because we are already full then this + * completes the callback on the error synchronously. + * + * @param key The RequestKey for the Connection. + * @param callback The callback to complete with the NextConnection. + */ + private def createConnection(key: RequestKey, callback: Callback[NextConnection]): F[Unit] = + F.ifM(F.delay(numConnectionsCheckHolds(key)))( + incrConnection(key) *> F.start { + builder(key).attempt + .flatMap { + case Right(conn) => + F.delay(callback(Right(NextConnection(conn, fresh = true)))) + case Left(error) => + disposeConnection(key, None) *> F.delay(callback(Left(error))) + } + .evalOn(executionContext) + }.void, + addToWaitQueue(key, callback), + ) + + private def addToWaitQueue(key: RequestKey, callback: Callback[NextConnection]): F[Unit] = + F.delay { + if (waitQueue.length < maxWaitQueueLimit) { + waitQueue.enqueue(Waiting(key, callback, Instant.now())) + () + } else { + logger.error( + s"Max wait queue for limit of $maxWaitQueueLimit for $key reached, not scheduling." + ) + callback(Left(WaitQueueFullFailure())) + } + } + + private def addToIdleQueue(conn: A, key: RequestKey): F[Unit] = + F.delay { + val borrowDeadline = maxIdleDuration match { + case finite: FiniteDuration => Some(Deadline.now + finite) + case _ => None + } + val q = idleQueues.getOrElse(key, mutable.Queue.empty[PooledConnection]) + q.enqueue(PooledConnection(conn, borrowDeadline)) + idleQueues.update(key, q) + } + + /** This generates a effect of Next Connection. The following calls are executed asynchronously + * with respect to whenever the execution of this task can occur. + * + * If the pool is closed the effect failure is executed. + * + * If the pool is not closed then we look for any connections in the idleQueues that match + * the RequestKey requested. + * If a matching connection exists and it is stil open the callback is executed with the connection. + * If a matching connection is closed we deallocate and repeat the check through the idleQueues. + * If no matching connection is found, and the pool is not full we create a new Connection to perform + * the request. + * If no matching connection is found and the pool is full, and we have connections in the idleQueues + * then a connection in the idleQueues is shutdown and a new connection is created to perform the request. + * If no matching connection is found and the pool is full, and all connections are currently in use + * then the Request is placed in a waitingQueue to be executed when a connection is released. + * + * @param key The Request Key For The Connection + * @return An effect of NextConnection + */ + def borrow(key: RequestKey): F[NextConnection] = + F.async { callback => + semaphore.permit.use { _ => + if (!isClosed) { + def go(): F[Unit] = + getConnectionFromQueue(key).flatMap { + case Some(pooled) if pooled.conn.isClosed => + F.delay(logger.debug(s"Evicting closed connection for $key: $stats")) *> + decrConnection(key) *> + go() + + case Some(pooled) if pooled.borrowDeadline.exists(_.isOverdue()) => + F.delay( + logger.debug(s"Shutting down and evicting expired connection for $key: $stats") + ) *> + decrConnection(key) *> + F.delay(pooled.conn.shutdown()) *> + go() + + case Some(pooled) => + F.delay(logger.debug(s"Recycling connection for $key: $stats")) *> + F.delay(callback(Right(NextConnection(pooled.conn, fresh = false)))) + + case None if numConnectionsCheckHolds(key) => + F.delay( + logger.debug(s"Active connection not found for $key. Creating new one. $stats") + ) *> + createConnection(key, callback) + + case None if maxConnectionsPerRequestKey(key) <= 0 => + F.delay(callback(Left(NoConnectionAllowedException(key)))) + + case None if curTotal == maxTotal => + val keys = idleQueues.keys + if (keys.nonEmpty) + F.delay( + logger.debug( + s"No connections available for the desired key, $key. Evicting random and creating a new connection: $stats" + ) + ) *> + F.delay(keys.iterator.drop(Random.nextInt(keys.size)).next()).flatMap { + randKey => + getConnectionFromQueue(randKey).map( + _.fold( + logger.warn(s"No connection to evict from the idleQueue for $randKey") + )(_.conn.shutdown()) + ) *> + decrConnection(randKey) + } *> + createConnection(key, callback) + else + F.delay( + logger.debug( + s"No connections available for the desired key, $key. Adding to waitQueue: $stats" + ) + ) *> + addToWaitQueue(key, callback) + + case None => // we're full up. Add to waiting queue. + F.delay( + logger.debug( + s"No connections available for $key. Waiting on new connection: $stats" + ) + ) *> + addToWaitQueue(key, callback) + } + + F.delay(logger.debug(s"Requesting connection for $key: $stats")).productR(go()).as(None) + } else + F.delay(callback(Left(new IllegalStateException("Connection pool is closed")))).as(None) + } + } + + private def releaseRecyclable(key: RequestKey, connection: A): F[Unit] = + F.delay(waitQueue.dequeueFirst(_.key == key)).flatMap { + case Some(Waiting(_, callback, at)) => + if (isRequestExpired(at)) + F.delay(logger.debug(s"Request expired for $key")) *> + F.delay(callback(Left(WaitQueueTimeoutException))) *> + releaseRecyclable(key, connection) + else + F.delay(logger.debug(s"Fulfilling waiting connection request for $key: $stats")) *> + F.delay(callback(Right(NextConnection(connection, fresh = false)))) + + case None if waitQueue.isEmpty => + F.delay(logger.debug(s"Returning idle connection to pool for $key: $stats")) *> + addToIdleQueue(connection, key) + + case None => + findFirstAllowedWaiter.flatMap { + case Some(Waiting(k, cb, _)) => + // This is the first waiter not blocked on the request key limit. + // close the undesired connection and wait for another + F.delay(connection.shutdown()) *> + decrConnection(key) *> + createConnection(k, cb) + + case None => + // We're blocked not because of too many connections, but + // because of too many connections per key. + // We might be able to reuse this request. + addToIdleQueue(connection, key) + } + } + + private def releaseNonRecyclable(key: RequestKey, connection: A): F[Unit] = + decrConnection(key) *> + F.delay { + if (!connection.isClosed) { + logger.debug(s"Connection returned was busy for $key. Shutting down: $stats") + connection.shutdown() + } + } *> + findFirstAllowedWaiter.flatMap { + case Some(Waiting(k, callback, _)) => + F.delay( + logger + .debug( + s"Connection returned could not be recycled, new connection needed for $key: $stats" + ) + ) *> + createConnection(k, callback) + + case None => + F.delay( + logger.debug( + s"Connection could not be recycled for $key, no pending requests. Shrinking pool: $stats" + ) + ) + } + + /** This is how connections are returned to the ConnectionPool. + * + * If the pool is closed the connection is shutdown and logged. + * If it is not closed we check if the connection is recyclable. + * + * If the connection is Recyclable we check if any of the connections in the waitQueue + * are looking for the returned connections RequestKey. + * If one is the first found is given the connection.And runs it using its callback asynchronously. + * If one is not found and the waitingQueue is Empty then we place the connection on the idle queue. + * If the waiting queue is not empty and we did not find a match then we shutdown the connection + * and create a connection for the first item in the waitQueue. + * + * If it is not recyclable, and it is not shutdown we shutdown the connection. If there + * are values in the waitQueue we create a connection and execute the callback asynchronously. + * Otherwise the pool is shrunk. + * + * @param connection The connection to be released. + * @return An effect of Unit + */ + def release(connection: A): F[Unit] = { + val key = connection.requestKey + semaphore.permit.use { _ => + connection.isRecyclable + .ifM(releaseRecyclable(key, connection), releaseNonRecyclable(key, connection)) + } + } + + private def findFirstAllowedWaiter: F[Option[Waiting]] = + F.delay { + val (expired, rest) = waitQueue.span(w => isRequestExpired(w.at)) + expired.foreach(_.callback(Left(WaitQueueTimeoutException))) + if (expired.nonEmpty) { + logger.debug(s"expired requests: ${expired.length}") + waitQueue = rest + logger.debug(s"Dropped expired requests: $stats") + } + waitQueue.dequeueFirst { waiter => + allocated.getOrElse(waiter.key, 0) < maxConnectionsPerRequestKey(waiter.key) + } + } + + /** This invalidates a Connection. This is what is exposed externally, and + * is just an effect wrapper around disposing the connection. + * + * @param connection The connection to invalidate + * @return An effect of Unit + */ + override def invalidate(connection: A): F[Unit] = + semaphore.permit.use { _ => + val key = connection.requestKey + decrConnection(key) *> + F.delay(if (!connection.isClosed) connection.shutdown()) *> + findFirstAllowedWaiter.flatMap { + case Some(Waiting(k, callback, _)) => + F.delay( + logger.debug(s"Invalidated connection for $key, new connection needed: $stats") + ) *> + createConnection(k, callback) + + case None => + F.delay( + logger.debug( + s"Invalidated connection for $key, no pending requests. Shrinking pool: $stats" + ) + ) + } + } + + /** Synchronous Immediate Disposal of a Connection and Its Resources. + * + * By taking an Option of a connection this also serves as a synchronized allocated decrease. + * + * @param key The request key for the connection. Not used internally. + * @param connection An Option of a Connection to Dispose Of. + */ + private def disposeConnection(key: RequestKey, connection: Option[A]): F[Unit] = + semaphore.permit.use { _ => + F.delay(logger.debug(s"Disposing of connection for $key: $stats")) *> + decrConnection(key) *> + F.delay { + connection.foreach { s => + if (!s.isClosed) s.shutdown() + } + } + } + + /** Shuts down the connection pool permanently. + * + * Changes isClosed to true, no methods can reopen a closed Pool. + * Shutdowns all connections in the IdleQueue and Sets Allocated to Zero + * + * @return An effect Of Unit + */ + def shutdown: F[Unit] = + semaphore.permit.use { _ => + F.delay { + logger.info(s"Shutting down connection pool: $stats") + if (!isClosed) { + isClosed = true + idleQueues.foreach(_._2.foreach(_.conn.shutdown())) + idleQueues.clear() + allocated.clear() + curTotal = 0 + } + } + } + + def state: BlazeClientState[F] = + new BlazeClientState[F] { + def isClosed: F[Boolean] = F.delay(self.isClosed) + def allocated: F[Map[RequestKey, Int]] = F.delay(self.allocated.toMap) + def idleQueueDepth: F[Map[RequestKey, Int]] = + F.delay(CollectionCompat.mapValues(self.idleQueues.toMap)(_.size)) + def waitQueueDepth: F[Int] = F.delay(self.waitQueue.size) + } +} + +final case class NoConnectionAllowedException(key: RequestKey) + extends IllegalArgumentException(s"No client connections allowed to $key") diff --git a/blaze-client/src/main/scala/org/http4s/blaze/client/bits.scala b/blaze-client/src/main/scala/org/http4s/blaze/client/bits.scala new file mode 100644 index 000000000..d60f65f49 --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/blaze/client/bits.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.client + +import org.http4s.BuildInfo +import org.http4s.ProductId +import org.http4s.headers.`User-Agent` + +import java.security.SecureRandom +import java.security.cert.X509Certificate +import javax.net.ssl.SSLContext +import javax.net.ssl.X509TrustManager +import scala.concurrent.duration._ + +private[http4s] object bits { + // Some default objects + val DefaultResponseHeaderTimeout: Duration = 10.seconds + val DefaultTimeout: Duration = 60.seconds + val DefaultBufferSize: Int = 8 * 1024 + val DefaultUserAgent: Option[`User-Agent`] = Some( + `User-Agent`(ProductId("http4s-blaze", Some(BuildInfo.version))) + ) + val DefaultMaxTotalConnections = 10 + val DefaultMaxWaitQueueLimit = 256 + + /** Caution: trusts all certificates and disables endpoint identification */ + @deprecated( + "Kept for binary compatibility. Unfit for production. Embeds a blocking call on some platforms.", + "0.23.13", + ) + lazy val TrustingSslContext: SSLContext = { + val trustManager = new X509TrustManager { + def getAcceptedIssuers(): Array[X509Certificate] = Array.empty + def checkClientTrusted(certs: Array[X509Certificate], authType: String): Unit = {} + def checkServerTrusted(certs: Array[X509Certificate], authType: String): Unit = {} + } + val sslContext = SSLContext.getInstance("TLS") + sslContext.init(null, Array(trustManager), new SecureRandom) + sslContext + } +} diff --git a/blaze-client/src/main/scala/org/http4s/client/blaze/package.scala b/blaze-client/src/main/scala/org/http4s/client/blaze/package.scala new file mode 100644 index 000000000..75c95db0f --- /dev/null +++ b/blaze-client/src/main/scala/org/http4s/client/blaze/package.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.client + +package object blaze { + @deprecated("use org.http4s.blaze.client.BlazeClientBuilder", "0.22") + type BlazeClientBuilder[F[_]] = org.http4s.blaze.client.BlazeClientBuilder[F] + + @deprecated("use org.http4s.blaze.client.BlazeClientBuilder", "0.22") + val BlazeClientBuilder = org.http4s.blaze.client.BlazeClientBuilder +} diff --git a/blaze-client/src/test/scala-2.13/org/http4s/client/blaze/BlazeClient213Suite.scala b/blaze-client/src/test/scala-2.13/org/http4s/client/blaze/BlazeClient213Suite.scala new file mode 100644 index 000000000..db2ce5aff --- /dev/null +++ b/blaze-client/src/test/scala-2.13/org/http4s/client/blaze/BlazeClient213Suite.scala @@ -0,0 +1,168 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.client +package blaze + +import cats.effect._ +import cats.syntax.all._ +import fs2.Stream +import org.http4s._ +import org.http4s.blaze.client.BlazeClientBase + +import java.util.concurrent.TimeUnit +import scala.concurrent.duration._ +import scala.util.Random + +class BlazeClient213Suite extends BlazeClientBase { + override def munitTimeout: Duration = new FiniteDuration(50, TimeUnit.SECONDS) + + test("reset request timeout".flaky) { + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + + Ref[IO] + .of(0L) + .flatMap { _ => + builder(1, requestTimeout = 2.second).resource.use { client => + val submit = + client.status(Request[IO](uri = Uri.fromString(s"http://$name:$port/simple").yolo)) + submit *> IO.sleep(3.seconds) *> submit + } + } + .assertEquals(Status.Ok) + } + + test("Blaze Http1Client should behave and not deadlock") { + val addresses = server().addresses + val hosts = addresses.map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/simple").yolo + } + + builder(3).resource + .use { client => + (1 to Runtime.getRuntime.availableProcessors * 5).toList + .parTraverse { _ => + val h = hosts(Random.nextInt(hosts.length)) + client.expect[String](h).map(_.nonEmpty) + } + .map(_.forall(identity)) + } + .assertEquals(true) + } + + test("behave and not deadlock on failures with parTraverse") { + val addresses = server().addresses + builder(3).resource + .use { client => + val failedHosts = addresses.map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/internal-server-error").yolo + } + + val successHosts = addresses.map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/simple").yolo + } + + val failedRequests = + (1 to Runtime.getRuntime.availableProcessors * 5).toList.parTraverse { _ => + val h = failedHosts(Random.nextInt(failedHosts.length)) + client.expect[String](h) + } + + val sucessRequests = + (1 to Runtime.getRuntime.availableProcessors * 5).toList.parTraverse { _ => + val h = successHosts(Random.nextInt(successHosts.length)) + client.expect[String](h).map(_.nonEmpty) + } + + val allRequests = for { + _ <- failedRequests.handleErrorWith(_ => IO.unit).replicateA(5) + r <- sucessRequests + } yield r + + allRequests + .map(_.forall(identity)) + } + .assertEquals(true) + } + + test("Blaze Http1Client should behave and not deadlock on failures with parSequence".flaky) { + val addresses = server().addresses + builder(3).resource + .use { client => + val failedHosts = addresses.map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/internal-server-error").yolo + } + + val successHosts = addresses.map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/simple").yolo + } + + val failedRequests = (1 to Runtime.getRuntime.availableProcessors * 5).toList.map { _ => + val h = failedHosts(Random.nextInt(failedHosts.length)) + client.expect[String](h) + }.parSequence + + val sucessRequests = (1 to Runtime.getRuntime.availableProcessors * 5).toList.map { _ => + val h = successHosts(Random.nextInt(successHosts.length)) + client.expect[String](h).map(_.nonEmpty) + }.parSequence + + val allRequests = for { + _ <- failedRequests.handleErrorWith(_ => IO.unit).replicateA(5) + r <- sucessRequests + } yield r + + allRequests + .map(_.forall(identity)) + } + .assertEquals(true) + } + + test("call a second host after reusing connections on a first") { + val addresses = server().addresses + // https://github.com/http4s/http4s/pull/2546 + builder(maxConnectionsPerRequestKey = Int.MaxValue, maxTotalConnections = 5).resource + .use { client => + val uris = addresses.take(2).map { address => + val name = address.host + val port = address.port + Uri.fromString(s"http://$name:$port/simple").yolo + } + val s = Stream( + Stream.eval( + client.expect[String](Request[IO](uri = uris(0))) + ) + ).repeat.take(10).parJoinUnbounded ++ Stream.eval( + client.expect[String](Request[IO](uri = uris(1))) + ) + s.compile.lastOrError + } + .assertEquals("simple path") + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBase.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBase.scala new file mode 100644 index 000000000..9cf9d3283 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBase.scala @@ -0,0 +1,150 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze +package client + +import cats.effect._ +import cats.effect.kernel.Resource +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import cats.implicits.catsSyntaxApplicativeId +import fs2.Stream +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.HttpMethod +import io.netty.handler.codec.http.HttpRequest +import io.netty.handler.codec.http.HttpResponseStatus +import munit.CatsEffectSuite +import org.http4s.Status.Ok +import org.http4s._ +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.client.testkit.scaffold._ +import org.http4s.client.testkit.testroutes.GetRoutes +import org.http4s.dsl.io._ + +import java.security.SecureRandom +import java.security.cert.X509Certificate +import javax.net.ssl.SSLContext +import javax.net.ssl.X509TrustManager +import scala.concurrent.duration._ + +trait BlazeClientBase extends CatsEffectSuite { + val tickWheel: TickWheelExecutor = new TickWheelExecutor(tick = 50.millis) + + val TrustingSslContext: IO[SSLContext] = IO.blocking { + val trustManager = new X509TrustManager { + def getAcceptedIssuers(): Array[X509Certificate] = Array.empty + def checkClientTrusted(certs: Array[X509Certificate], authType: String): Unit = {} + def checkServerTrusted(certs: Array[X509Certificate], authType: String): Unit = {} + } + val ctx = SSLContext.getInstance("TLS") + ctx.init(null, Array(trustManager), new SecureRandom()) + ctx + } + + def builder( + maxConnectionsPerRequestKey: Int, + maxTotalConnections: Int = 5, + responseHeaderTimeout: Duration = 30.seconds, + requestTimeout: Duration = 45.seconds, + chunkBufferMaxSize: Int = 1024, + sslContextOption: Option[SSLContext] = None, + retries: Int = 0, + ): BlazeClientBuilder[IO] = { + val builder: BlazeClientBuilder[IO] = + BlazeClientBuilder[IO] + .withCheckEndpointAuthentication(false) + .withResponseHeaderTimeout(responseHeaderTimeout) + .withRequestTimeout(requestTimeout) + .withMaxTotalConnections(maxTotalConnections) + .withMaxConnectionsPerRequestKey(Function.const(maxConnectionsPerRequestKey)) + .withChunkBufferMaxSize(chunkBufferMaxSize) + .withScheduler(scheduler = tickWheel) + .withRetries(retries) + + sslContextOption.fold[BlazeClientBuilder[IO]](builder.withoutSslContext)(builder.withSslContext) + } + + private def makeScaffold(num: Int, secure: Boolean): Resource[IO, ServerScaffold[IO]] = + for { + dispatcher <- Dispatcher[IO] + getHandler <- Resource.eval( + RoutesToHandlerAdapter( + HttpRoutes.of[IO] { + case Method.GET -> Root / "infinite" => + Response[IO](Ok).withEntity(Stream.emit[IO, String]("a" * 8 * 1024).repeat).pure[IO] + + case _ @(Method.GET -> path) => + GetRoutes.getPaths.getOrElse(path.toString, NotFound()) + }, + dispatcher, + ) + ) + scaffold <- ServerScaffold[IO]( + num, + secure, + HandlersToNettyAdapter[IO](postHandlers, getHandler), + ) + } yield scaffold + + private def postHandlers: Map[(HttpMethod, String), Handler] = + Map( + (HttpMethod.POST, "/respond-and-close-immediately") -> new Handler { + // The client may receive the response before sending the whole request + override def onRequestStart(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + HandlerHelpers.sendResponse( + ctx, + HttpResponseStatus.OK, + HandlerHelpers.utf8Text("a"), + closeConnection = true, + ) + () + } + + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = () + }, + (HttpMethod.POST, "/respond-and-close-immediately-no-body") -> new Handler { + // The client may receive the response before sending the whole request + override def onRequestStart(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + HandlerHelpers.sendResponse(ctx, HttpResponseStatus.OK, closeConnection = true) + () + } + + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = () + }, + (HttpMethod.POST, "/process-request-entity") -> new Handler { + // We wait for the entire request to arrive before sending a response. That's how servers normally behave. + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + HandlerHelpers.sendResponse(ctx, HttpResponseStatus.OK) + () + } + }, + ) + + val server: Fixture[ServerScaffold[IO]] = + ResourceSuiteLocalFixture("http", makeScaffold(2, false)) + val secureServer: Fixture[ServerScaffold[IO]] = + ResourceSuiteLocalFixture("https", makeScaffold(1, true)) + + override val munitFixtures = List( + server, + secureServer, + ) + + implicit class ParseResultSyntax[A](self: ParseResult[A]) { + def yolo: A = self.valueOr(e => sys.error(e.toString)) + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBuilderSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBuilderSuite.scala new file mode 100644 index 000000000..df5c58602 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientBuilderSuite.scala @@ -0,0 +1,103 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.IO +import munit.CatsEffectSuite +import org.http4s.blaze.channel.ChannelOptions + +class BlazeClientBuilderSuite extends CatsEffectSuite { + private def builder = BlazeClientBuilder[IO] + + test("default to empty") { + assertEquals(builder.channelOptions, ChannelOptions(Vector.empty)) + } + + test("set socket send buffer size") { + assertEquals(builder.withSocketSendBufferSize(8192).socketSendBufferSize, Some(8192)) + } + + test("set socket receive buffer size") { + assertEquals(builder.withSocketReceiveBufferSize(8192).socketReceiveBufferSize, Some(8192)) + } + + test("set socket keepalive") { + assertEquals(builder.withSocketKeepAlive(true).socketKeepAlive, Some(true)) + } + + test("set socket reuse address") { + assertEquals(builder.withSocketReuseAddress(true).socketReuseAddress, Some(true)) + } + + test("set TCP nodelay") { + assertEquals(builder.withTcpNoDelay(true).tcpNoDelay, Some(true)) + } + + test("unset socket send buffer size") { + assertEquals( + builder + .withSocketSendBufferSize(8192) + .withDefaultSocketSendBufferSize + .socketSendBufferSize, + None, + ) + } + + test("unset socket receive buffer size") { + assertEquals( + builder + .withSocketReceiveBufferSize(8192) + .withDefaultSocketReceiveBufferSize + .socketReceiveBufferSize, + None, + ) + } + + test("unset socket keepalive") { + assertEquals(builder.withSocketKeepAlive(true).withDefaultSocketKeepAlive.socketKeepAlive, None) + } + + test("unset socket reuse address") { + assertEquals( + builder + .withSocketReuseAddress(true) + .withDefaultSocketReuseAddress + .socketReuseAddress, + None, + ) + } + + test("unset TCP nodelay") { + assertEquals(builder.withTcpNoDelay(true).withDefaultTcpNoDelay.tcpNoDelay, None) + } + + test("overwrite keys") { + assertEquals( + builder + .withSocketSendBufferSize(8192) + .withSocketSendBufferSize(4096) + .socketSendBufferSize, + Some(4096), + ) + } + + test("set header max length") { + assertEquals(builder.withMaxHeaderLength(64).maxHeaderLength, 64) + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientConnectionReuseSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientConnectionReuseSuite.scala new file mode 100644 index 000000000..e939c6c36 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientConnectionReuseSuite.scala @@ -0,0 +1,247 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze +package client + +import cats.effect._ +import cats.implicits._ +import fs2.Chunk +import fs2.Stream +import org.http4s.Method._ +import org.http4s._ +import org.http4s.client.testkit.scaffold.TestServer + +import java.net.SocketException +import java.util.concurrent.TimeUnit +import scala.concurrent.duration._ + +class BlazeClientConnectionReuseSuite extends BlazeClientBase { + override def munitTimeout: Duration = new FiniteDuration(50, TimeUnit.SECONDS) + + test("BlazeClient should reuse the connection after a simple successful request") { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + test("BlazeClient should reuse the connection after a successful request with large response") { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "large")) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + test( + "BlazeClient.status should reuse the connection after receiving a response without an entity" + ) { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client.status(Request[IO](GET, servers(0).uri / "no-content")) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + // BlazeClient.status may or may not reuse the connection after receiving a response with an entity. + // It's up to the implementation. + // The connection can be reused only if the entity has been fully read from the socket. + // The current BlazeClient implementation will reuse the connection if it read the entire entity while reading the status line and headers. + // This behaviour depends on `BlazeClientBuilder.bufferSize`. + // In particular, responses not bigger than `bufferSize` will lead to reuse of the connection. + + test( + "BlazeClient.status shouldn't wait for an infinite response entity and shouldn't reuse the connection" + ) { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client + .status(Request[IO](GET, servers(0).uri / "infinite")) + .timeout(5.seconds) // we expect it to complete without waiting for the response body + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(2L) + } yield () + } + } + + test("BlazeClient should reuse connections to different servers separately") { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + _ <- servers(1).establishedConnections.assertEquals(0L) + _ <- client.expect[String](Request[IO](GET, servers(1).uri / "simple")) + _ <- client.expect[String](Request[IO](GET, servers(1).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + _ <- servers(1).establishedConnections.assertEquals(1L) + } yield () + } + } + + // // Decoding failures // // + + test("BlazeClient should reuse the connection after response decoding failed") { + // This will work regardless of whether we drain the entity or not, + // because the response is small and it is read in full in first read operation + val drainThenFail = EntityDecoder.error[IO, String](new Exception()) + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client + .expect[String](Request[IO](GET, servers(0).uri / "simple"))(drainThenFail) + .attempt + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + test( + "BlazeClient should reuse the connection after response decoding failed and the (large) entity was drained" + ) { + val drainThenFail = EntityDecoder.error[IO, String](new Exception()) + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client + .expect[String](Request[IO](GET, servers(0).uri / "large"))(drainThenFail) + .attempt + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + test( + "BlazeClient shouldn't reuse the connection after response decoding failed and the (large) entity wasn't drained" + ) { + val failWithoutDraining = new EntityDecoder[IO, String] { + override def decode(m: Media[IO], strict: Boolean): DecodeResult[IO, String] = + DecodeResult[IO, String](IO.raiseError(new Exception())) + override def consumes: Set[MediaRange] = Set.empty + } + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client + .expect[String](Request[IO](GET, servers(0).uri / "large"))(failWithoutDraining) + .attempt + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(2L) + } yield () + } + } + + // // Requests with an entity // // + + test("BlazeClient should reuse the connection after a request with an entity") { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client.expect[String]( + Request[IO](POST, servers(0).uri / "process-request-entity").withEntity("entity") + ) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(1L) + } yield () + } + } + + // TODO investigate delay in sending first chunk (it waits for 2 complete 32kB chunks) + + test( + "BlazeClient shouldn't wait for the request entity transfer to complete if the server closed the connection early. The closed connection shouldn't be reused.".flaky + ) { + builder().resource.use { client => + for { + servers <- makeServers() + // In a typical execution of this test the server receives the beginning of the requests, responds, and then sends FIN. The client processes the response, processes the FIN, and then stops sending the request entity. The server may then send an RST, but if the client already processed the response then it's not a problem. + // But sometimes the server receives the beginning of the request, responds, sends FIN and then sends RST. There's a race between delivering the response to the application and acting on the RST, that is closing the socket and delivering an EOF. That means that the request may fail with "SocketException: HTTP connection closed". + // I don't know how to prevent the second scenario. So instead I relaxed the requirement expressed in this test to accept both successes and SocketExceptions, and only require timely completion of the request, and disposal of the connection. + _ <- client + .expect[String]( + Request[IO](POST, servers(0).uri / "respond-and-close-immediately") + .withBodyStream( + Stream + .fixedDelay[IO](10.milliseconds) + .mapChunks(_ => Chunk.array(Array.fill(10000)(0.toByte))) + ) + ) + .recover { case _: SocketException => "" } + .timeout(2.seconds) + _ <- client.expect[String](Request[IO](GET, servers(0).uri / "simple")) + _ <- servers(0).establishedConnections.assertEquals(2L) + } yield () + } + } + + // // Load tests // // + + test( + "BlazeClient should keep reusing connections even when under heavy load (single client scenario)" + ) { + builder().resource.use { client => + for { + servers <- makeServers() + _ <- client + .expect[String](Request[IO](GET, servers(0).uri / "simple")) + .replicateA(200) + .parReplicateA(20) + // There's no guarantee we'll actually manage to use 20 connections in parallel. Sharing the client means sharing the lock inside PoolManager as a contention point. + // But if the connections are reused correctly, we shouldn't use more than 20. + _ <- servers(0).establishedConnections.map(_ <= 20L).assert + } yield () + } + } + + test( + "BlazeClient should keep reusing connections even when under heavy load (multiple clients scenario)" + ) { + for { + servers <- makeServers() + _ <- builder().resource + .use { client => + client.expect[String](Request[IO](GET, servers(0).uri / "simple")).replicateA(400) + } + .parReplicateA(20) + _ <- servers(0).establishedConnections.assertEquals(20L) + } yield () + } + + private def builder(): BlazeClientBuilder[IO] = + BlazeClientBuilder[IO].withScheduler(scheduler = tickWheel) + + private def makeServers(): IO[Vector[TestServer[IO]]] = { + val testServers = server().servers + testServers + .traverse(_.resetEstablishedConnections) + .as(testServers) + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientSuite.scala new file mode 100644 index 000000000..75bb7c412 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeClientSuite.scala @@ -0,0 +1,380 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.syntax.all._ +import fs2.Stream +import fs2.io.net.Network +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.http.HttpMethod +import io.netty.handler.codec.http.HttpRequest +import io.netty.handler.codec.http.HttpResponseStatus +import org.http4s.client.ConnectionFailure +import org.http4s.client.RequestKey +import org.http4s.client.testkit.scaffold.Handler +import org.http4s.client.testkit.scaffold.HandlerHelpers +import org.http4s.client.testkit.scaffold.HandlersToNettyAdapter +import org.http4s.client.testkit.scaffold.ServerScaffold +import org.http4s.syntax.all._ + +import java.net.SocketException +import java.util.concurrent.TimeoutException +import scala.concurrent.duration._ + +class BlazeClientSuite extends BlazeClientBase { + + test( + "Blaze Http1Client should raise error NoConnectionAllowedException if no connections are permitted for key" + ) { + val sslAddress = secureServer().addresses.head + val name = sslAddress.host + val port = sslAddress.port + val u = Uri.fromString(s"https://$name:$port/simple").yolo + val resp = builder(0).resource.use(_.expect[String](u).attempt) + resp.assertEquals(Left(NoConnectionAllowedException(RequestKey(u.scheme.get, u.authority.get)))) + } + + test("Blaze Http1Client should make simple https requests") { + val sslAddress = secureServer().addresses.head + val name = sslAddress.host + val port = sslAddress.port + val u = Uri.fromString(s"https://$name:$port/simple").yolo + TrustingSslContext + .flatMap { ctx => + val resp = builder(1, sslContextOption = Some(ctx)).resource.use(_.expect[String](u)) + resp.map(_.length > 0) + } + .assertEquals(true) + } + + test("Blaze Http1Client should reject https requests when no SSLContext is configured") { + val sslAddress = secureServer().addresses.head + val name = sslAddress.host + val port = sslAddress.port + val u = Uri.fromString(s"https://$name:$port/simple").yolo + val resp = builder(1, sslContextOption = None).resource + .use(_.expect[String](u)) + .attempt + resp + .map { + case Left(_: ConnectionFailure) => true + case _ => false + } + .assertEquals(true) + } + + test("Blaze Http1Client should obey response header timeout") { + val addresses = server().addresses + val address = addresses(0) + val name = address.host + val port = address.port + builder(1, responseHeaderTimeout = 100.millis).resource + .use { client => + val submit = client.expect[String](Uri.fromString(s"http://$name:$port/delayed").yolo) + submit + } + .intercept[TimeoutException] + } + + test("Blaze Http1Client should unblock waiting connections") { + val addresses = server().addresses + val address = addresses(0) + val name = address.host + val port = address.port + builder(1, responseHeaderTimeout = 20.seconds).resource + .use { client => + val submit = client.expect[String](Uri.fromString(s"http://$name:$port/delayed").yolo) + for { + _ <- submit.start + r <- submit.attempt + } yield r + } + .map(_.isRight) + .assertEquals(true) + } + + test("Blaze Http1Client should drain waiting connections after shutdown") { + val addresses = server().addresses + val address = addresses(0) + val name = address.host + val port = address.port + + val resp = builder(1, responseHeaderTimeout = 20.seconds).resource + .use { drainTestClient => + drainTestClient + .expect[String](Uri.fromString(s"http://$name:$port/delayed").yolo) + .attempt + .start + + val resp = drainTestClient + .expect[String](Uri.fromString(s"http://$name:$port/delayed").yolo) + .attempt + .map(_.exists(_.nonEmpty)) + .start + + // Wait 100 millis to shut down + IO.sleep(100.millis) *> resp.flatMap(_.joinWithNever) + } + + resp.assertEquals(true) + } + + test( + "Blaze Http1Client should stop sending data when the server sends response and closes connection" + ) { + // https://datatracker.ietf.org/doc/html/rfc2616#section-8.2.2 + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + Deferred[IO, Unit] + .flatMap { reqClosed => + builder(1, requestTimeout = 60.seconds).resource.use { client => + val body = Stream(0.toByte).repeat.onFinalizeWeak(reqClosed.complete(()).void) + val req = Request[IO]( + method = Method.POST, + uri = Uri.fromString(s"http://$name:$port/respond-and-close-immediately").yolo, + ).withBodyStream(body) + client.status(req) >> reqClosed.get + } + } + .assertEquals(()) + } + + test( + "Blaze Http1Client should stop sending data when the server sends response without body and closes connection" + ) { + // https://datatracker.ietf.org/doc/html/rfc2616#section-8.2.2 + // Receiving a response with and without body exercises different execution path in blaze client. + + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + Deferred[IO, Unit] + .flatMap { reqClosed => + builder(1, requestTimeout = 60.seconds).resource.use { client => + val body = Stream(0.toByte).repeat.onFinalizeWeak(reqClosed.complete(()).void) + val req = Request[IO]( + method = Method.POST, + uri = Uri.fromString(s"http://$name:$port/respond-and-close-immediately-no-body").yolo, + ).withBodyStream(body) + client.status(req) >> reqClosed.get + } + } + .assertEquals(()) + } + + test( + "Blaze Http1Client should fail with request timeout if the request body takes too long to send" + ) { + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + builder(1, requestTimeout = 500.millis, responseHeaderTimeout = Duration.Inf).resource + .use { client => + val body = Stream(0.toByte).repeat + val req = Request[IO]( + method = Method.POST, + uri = Uri.fromString(s"http://$name:$port/process-request-entity").yolo, + ).withBodyStream(body) + client.status(req) + } + .attempt + .map { + case Left(_: TimeoutException) => true + case _ => false + } + .assert + } + + test( + "Blaze Http1Client should fail with response header timeout if the request body takes too long to send" + ) { + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + builder(1, requestTimeout = Duration.Inf, responseHeaderTimeout = 500.millis).resource + .use { client => + val body = Stream(0.toByte).repeat + val req = Request[IO]( + method = Method.POST, + uri = Uri.fromString(s"http://$name:$port/process-request-entity").yolo, + ).withBodyStream(body) + client.status(req) + } + .attempt + .map { + case Left(_: TimeoutException) => true + case _ => false + } + .assert + } + + test("Blaze Http1Client should doesn't leak connection on timeout".flaky) { + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + val uri = Uri.fromString(s"http://$name:$port/simple").yolo + + builder(1).resource + .use { client => + val req = Request[IO](uri = uri) + client + .run(req) + .use { _ => + IO.never + } + .timeout(250.millis) + .attempt >> + client.status(req) + } + .assertEquals(Status.Ok) + } + + test("Blaze Http1Client should raise a ConnectionFailure when a host can't be resolved") { + builder(1).resource + .use { client => + client.status(Request[IO](uri = uri"http://example.invalid/")) + } + .interceptMessage[ConnectionFailure]( + "Error connecting to http://example.invalid using address example.invalid:80 (unresolved: true)" + ) + } + + test("Blaze HTTP/1 client should raise a ResponseException when it receives an unexpected EOF") { + Network[IO] + .serverResource(address = None, port = None, options = Nil) + .map { case (addr, sockets) => + val uri = Uri.fromString(s"http://[${addr.host}]:${addr.port}/eof").yolo + val req = Request[IO](uri = uri) + (req, sockets) + } + .use { case (req, sockets) => + Stream + .eval(builder(1).resource.use { client => + interceptMessageIO[SocketException]( + s"HTTP connection closed: ${RequestKey.fromRequest(req)}" + )(client.expect[String](req)) + }) + .concurrently(sockets.evalMap(s => s.endOfInput *> s.endOfOutput)) + .compile + .drain + } + } + + test("Keeps stats".flaky) { + val addresses = server().addresses + val address = addresses.head + val name = address.host + val port = address.port + val uri = Uri.fromString(s"http://$name:$port/process-request-entity").yolo + builder(1, requestTimeout = 2.seconds).resourceWithState.use { case (client, state) => + for { + // We're not thoroughly exercising the pool stats. We're doing a rudimentary check. + _ <- state.allocated.assertEquals(Map.empty[RequestKey, Int]) + reading <- Deferred[IO, Unit] + done <- Deferred[IO, Unit] + body = Stream.eval(reading.complete(())) *> (Stream.empty: EntityBody[IO]) <* Stream.eval( + done.get + ) + req = Request[IO](Method.POST, uri = uri).withEntity(body) + _ <- client.status(req).start + _ <- reading.get + _ <- state.allocated.map(_.get(RequestKey.fromRequest(req))).assertEquals(Some(1)) + _ <- done.complete(()) + } yield () + } + } + + test("retries idempotent requests") { + Ref.of[IO, Int](0).flatMap { attempts => + val handlers = Map((HttpMethod.GET, "/close-without-response") -> new Handler { + override def onRequestStart(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + attempts.update(_ + 1).unsafeRunSync() + ctx.channel.close() + () + } + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = () + }) + ServerScaffold[IO](1, false, HandlersToNettyAdapter[IO](handlers)).use { server => + val address = server.addresses.head + val name = address.host + val port = address.port + val uri = Uri.fromString(s"http://$name:$port/close-without-response").yolo + val req = Request[IO](method = Method.GET, uri = uri) + val key = RequestKey.fromRequest(req) + builder(1, retries = 3).resourceWithState + .use { case (client, state) => + client.status(req).attempt *> attempts.get.assertEquals(4) *> + state.allocated.map(_.get(key)).assertEquals(None) + } + } + } + } + + test("does not retry non-idempotent requests") { + Ref.of[IO, Int](0).flatMap { attempts => + val handlers = Map((HttpMethod.POST, "/close-without-response") -> new Handler { + override def onRequestStart(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + attempts.update(_ + 1).unsafeRunSync() + ctx.channel.close() + () + } + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = () + }) + ServerScaffold[IO](1, false, HandlersToNettyAdapter[IO](handlers)).use { server => + val address = server.addresses.head + val name = address.host + val port = address.port + val uri = Uri.fromString(s"http://$name:$port/close-without-response").yolo + val req = Request[IO](method = Method.POST, uri = uri) + builder(1, retries = 3).resource + .use(client => client.status(req).attempt *> attempts.get.assertEquals(1)) + } + } + } + + test("does not retry requests that fail without a SocketException") { + Ref.of[IO, Int](0).flatMap { attempts => + val handlers = Map((HttpMethod.GET, "/500") -> new Handler { + override def onRequestStart(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + attempts.update(_ + 1).unsafeRunSync() + HandlerHelpers + .sendResponse(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, closeConnection = true) + () + } + override def onRequestEnd(ctx: ChannelHandlerContext, request: HttpRequest): Unit = () + }) + ServerScaffold[IO](1, false, HandlersToNettyAdapter[IO](handlers)).use { server => + val address = server.addresses.head + val name = address.host + val port = address.port + val uri = Uri.fromString(s"http://$name:$port/500").yolo + val req = Request[IO](method = Method.GET, uri = uri) + builder(1, retries = 3).resource + .use(client => client.status(req).attempt *> attempts.get.assertEquals(1)) + } + } + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeHttp1ClientSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeHttp1ClientSuite.scala new file mode 100644 index 000000000..8ab2e8c98 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/BlazeHttp1ClientSuite.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect.IO +import cats.effect.Resource +import org.http4s.client.Client +import org.http4s.client.testkit.ClientRouteTestBattery + +class BlazeHttp1ClientSuite extends ClientRouteTestBattery("BlazeClient") { + def clientResource: Resource[IO, Client[IO]] = + BlazeClientBuilder[IO].resource +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/ClientTimeoutSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/ClientTimeoutSuite.scala new file mode 100644 index 000000000..8eda423f6 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/ClientTimeoutSuite.scala @@ -0,0 +1,246 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.effect.std.Dispatcher +import cats.effect.std.Queue +import cats.syntax.all._ +import fs2.Chunk +import fs2.Stream +import munit.CatsEffectSuite +import org.http4s.blaze.pipeline.HeadStage +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.DispatcherIOFixture +import org.http4s.blazecore.IdleTimeoutStage +import org.http4s.blazecore.QueueTestHead +import org.http4s.blazecore.SlowTestHead +import org.http4s.client.Client +import org.http4s.client.RequestKey +import org.http4s.syntax.all._ + +import java.io.IOException +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + +class ClientTimeoutSuite extends CatsEffectSuite with DispatcherIOFixture { + + override def munitTimeout: Duration = 5.seconds + + private def tickWheelFixture = ResourceFixture( + Resource.make(IO(new TickWheelExecutor(tick = 50.millis)))(tickWheel => + IO(tickWheel.shutdown()) + ) + ) + + private def fixture = (tickWheelFixture, dispatcher).mapN(FunFixture.map2(_, _)) + + private val www_foo_com = uri"http://www.foo.com" + private val FooRequest = Request[IO](uri = www_foo_com) + private val FooRequestKey = RequestKey.fromRequest(FooRequest) + private val resp = "HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\ndone" + + private val chunkBufferMaxSize = 1024 * 1024 + + private def makeIdleTimeoutStage( + idleTimeout: Duration, + tickWheel: TickWheelExecutor, + ): Option[IdleTimeoutStage[ByteBuffer]] = + idleTimeout match { + case d: FiniteDuration => + Some(new IdleTimeoutStage[ByteBuffer](d, tickWheel, munitExecutionContext)) + case _ => None + } + + private def mkBuffer(s: String): ByteBuffer = + ByteBuffer.wrap(s.getBytes(StandardCharsets.ISO_8859_1)) + + private def mkClient( + head: => HeadStage[ByteBuffer], + tickWheel: TickWheelExecutor, + dispatcher: Dispatcher[IO], + )( + responseHeaderTimeout: Duration = Duration.Inf, + requestTimeout: Duration = Duration.Inf, + idleTimeout: Duration = Duration.Inf, + retries: Int = 0, + ): Client[IO] = { + val manager = ConnectionManager.basic[IO, Http1Connection[IO]]((_: RequestKey) => + IO { + val idleTimeoutStage = makeIdleTimeoutStage(idleTimeout, tickWheel) + val connection = mkConnection(idleTimeoutStage, dispatcher) + val builder = LeafBuilder(connection) + idleTimeoutStage + .fold(builder)(builder.prepend(_)) + .base(head) + connection + } + ) + BlazeClient.makeClient( + manager = manager, + responseHeaderTimeout = responseHeaderTimeout, + requestTimeout = requestTimeout, + scheduler = tickWheel, + ec = munitExecutionContext, + retries = retries, + dispatcher = dispatcher, + ) + } + + private def mkConnection( + idleTimeoutStage: Option[IdleTimeoutStage[ByteBuffer]], + dispatcher: Dispatcher[IO], + ): Http1Connection[IO] = + new Http1Connection[IO]( + requestKey = FooRequestKey, + executionContext = munitExecutionContext, + maxResponseLineSize = 4 * 1024, + maxHeaderLength = 40 * 1024, + maxChunkSize = Int.MaxValue, + chunkBufferMaxSize = chunkBufferMaxSize, + parserMode = ParserMode.Strict, + userAgent = None, + idleTimeoutStage = idleTimeoutStage, + dispatcher = dispatcher, + ) + + fixture.test("Idle timeout on slow response") { case (tickWheel, dispatcher) => + val h = new SlowTestHead(List(mkBuffer(resp)), 60.seconds, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(idleTimeout = 1.second) + + c.fetchAs[String](FooRequest).intercept[TimeoutException] + } + + fixture.test("Request timeout on slow response") { case (tickWheel, dispatcher) => + val h = new SlowTestHead(List(mkBuffer(resp)), 60.seconds, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(requestTimeout = 1.second) + + c.fetchAs[String](FooRequest).intercept[TimeoutException] + } + + fixture.test("Idle timeout on slow request body before receiving response") { + case (tickWheel, dispatcher) => + // Sending request body hangs so the idle timeout will kick-in after 1s and interrupt the request + val body = Stream.emit[IO, Byte](1.toByte) ++ Stream.never[IO] + val req = Request(method = Method.POST, uri = www_foo_com, body = body) + val h = new SlowTestHead(Seq(mkBuffer(resp)), 3.seconds, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(idleTimeout = 1.second) + + c.fetchAs[String](req).intercept[TimeoutException] + } + + fixture.test("Idle timeout on slow request body while receiving response body") { + case (tickWheel, dispatcher) => + // Sending request body hangs so the idle timeout will kick-in after 1s and interrupt the request. + // But with current implementation the cancellation of the request hangs (waits for the request body). + (for { + _ <- IO.unit + body = Stream.emit[IO, Byte](1.toByte) ++ Stream.never[IO] + req = Request(method = Method.POST, uri = www_foo_com, body = body) + q <- Queue.unbounded[IO, Option[ByteBuffer]] + h = new QueueTestHead(q) + (f, b) = resp.splitAt(resp.length - 1) + _ <- (q.offer(Some(mkBuffer(f))) >> IO.sleep(3.seconds) >> q.offer( + Some(mkBuffer(b)) + )).start + c = mkClient(h, tickWheel, dispatcher)(idleTimeout = 1.second) + s <- c.fetchAs[String](req) + } yield s).intercept[TimeoutException] + } + + fixture.test("Not timeout on only marginally slow request body".flaky) { + case (tickWheel, dispatcher) => + // Sending request body will take 1500ms. But there will be some activity every 500ms. + // If the idle timeout wasn't reset every time something is sent, it would kick-in after 1 second. + // The chunks need to be larger than the buffer in CachingChunkWriter + val body = Stream + .fixedRate[IO](500.millis) + .take(3) + .mapChunks(_ => Chunk.array(Array.fill(chunkBufferMaxSize + 1)(1.toByte))) + val req = Request(method = Method.POST, uri = www_foo_com, body = body) + val h = new SlowTestHead(Seq(mkBuffer(resp)), 2000.millis, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(idleTimeout = 1.second) + + c.fetchAs[String](req) + } + + fixture.test("Request timeout on slow response body".flaky) { case (tickWheel, dispatcher) => + val h = new SlowTestHead(Seq(mkBuffer(resp)), 1500.millis, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(requestTimeout = 1.second, idleTimeout = 10.second) + + c.fetchAs[String](FooRequest).intercept[TimeoutException] + } + + fixture.test("Idle timeout on slow response body") { case (tickWheel, dispatcher) => + val (f, b) = resp.splitAt(resp.length - 1) + (for { + q <- Queue.unbounded[IO, Option[ByteBuffer]] + _ <- q.offer(Some(mkBuffer(f))) + _ <- (IO.sleep(1500.millis) >> q.offer(Some(mkBuffer(b)))).start + h = new QueueTestHead(q) + c = mkClient(h, tickWheel, dispatcher)(idleTimeout = 500.millis) + s <- c.fetchAs[String](FooRequest) + } yield s).intercept[TimeoutException] + } + + fixture.test("Response head timeout on slow header") { case (tickWheel, dispatcher) => + val h = new SlowTestHead(Seq(mkBuffer(resp)), 10.seconds, tickWheel) + val c = mkClient(h, tickWheel, dispatcher)(responseHeaderTimeout = 500.millis) + c.fetchAs[String](FooRequest).intercept[TimeoutException] + } + + fixture.test("No Response head timeout on fast header".flaky) { case (tickWheel, dispatcher) => + val (f, b) = resp.splitAt(resp.indexOf("\r\n\r\n" + 4)) + val h = new SlowTestHead(Seq(f, b).map(mkBuffer), 125.millis, tickWheel) + // header is split into two chunks, we wait for 10x + val c = mkClient(h, tickWheel, dispatcher)(responseHeaderTimeout = 1250.millis) + + c.fetchAs[String](FooRequest).assertEquals("done") + } + + // Regression test for: https://github.com/http4s/http4s/issues/2386 + // and https://github.com/http4s/http4s/issues/2338 + fixture.test("Eventually timeout on connect timeout") { case (tickWheel, dispatcher) => + val manager = ConnectionManager.basic[IO, BlazeConnection[IO]] { _ => + // In a real use case this timeout is under OS's control (AsynchronousSocketChannel.connect) + IO.sleep(1000.millis) *> IO.raiseError[BlazeConnection[IO]](new IOException()) + } + val c = BlazeClient.makeClient( + manager = manager, + responseHeaderTimeout = Duration.Inf, + requestTimeout = 50.millis, + scheduler = tickWheel, + ec = munitExecutionContext, + retries = 0, + dispatcher = dispatcher, + ) + + // if the .timeout(1500.millis) is hit, it's a TimeoutException, + // if the requestTimeout is hit then it's a TimeoutException + // if establishing connection fails first then it's an IOException + + // The expected behaviour is that the requestTimeout will happen first, + // but will not be considered as long as BlazeClient is busy trying to obtain the connection. + // Obtaining the connection will fail after 1000 millis and that error will be propagated. + c.fetchAs[String](FooRequest).timeout(1500.millis).intercept[IOException] + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/Http1ClientStageSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/Http1ClientStageSuite.scala new file mode 100644 index 000000000..ac7e18333 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/Http1ClientStageSuite.scala @@ -0,0 +1,337 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.effect.kernel.Deferred +import cats.effect.std.Dispatcher +import cats.effect.std.Queue +import cats.syntax.all._ +import fs2.Stream +import munit.CatsEffectSuite +import org.http4s.BuildInfo +import org.http4s.blaze.client.bits.DefaultUserAgent +import org.http4s.blaze.pipeline.Command.EOF +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blazecore.DispatcherIOFixture +import org.http4s.blazecore.QueueTestHead +import org.http4s.blazecore.SeqTestHead +import org.http4s.blazecore.TestHead +import org.http4s.client.RequestKey +import org.http4s.headers.`User-Agent` +import org.http4s.syntax.all._ + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.concurrent.Future +import scala.concurrent.duration._ + +class Http1ClientStageSuite extends CatsEffectSuite with DispatcherIOFixture { + + private val trampoline = org.http4s.blaze.util.Execution.trampoline + + private val www_foo_test = uri"http://www.foo.test" + private val FooRequest = Request[IO](uri = www_foo_test) + private val FooRequestKey = RequestKey.fromRequest(FooRequest) + + // Common throw away response + val resp = "HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\ndone" + + private val fooConnection = + ResourceFixture[Http1Connection[IO]] { + for { + dispatcher <- Dispatcher[IO] + connection <- mkConnection(FooRequestKey, dispatcher) + } yield connection + } + + private def mkConnection( + key: RequestKey, + dispatcher: Dispatcher[IO], + userAgent: Option[`User-Agent`] = None, + ) = + Resource.make( + IO( + new Http1Connection[IO]( + key, + executionContext = trampoline, + maxResponseLineSize = 4096, + maxHeaderLength = 40960, + maxChunkSize = Int.MaxValue, + chunkBufferMaxSize = 1024, + parserMode = ParserMode.Strict, + userAgent = userAgent, + idleTimeoutStage = None, + dispatcher = dispatcher, + ) + ) + )(c => IO(c.shutdown())) + + private def mkBuffer(s: String): ByteBuffer = + ByteBuffer.wrap(s.getBytes(StandardCharsets.ISO_8859_1)) + + private def bracketResponse[T]( + req: Request[IO], + resp: String, + dispatcher: Dispatcher[IO], + ): Resource[IO, Response[IO]] = { + val connectionR = mkConnection(FooRequestKey, dispatcher) + + def stageResource(connection: Http1Connection[IO]) = + Resource.eval(IO { + val h = new SeqTestHead(resp.toSeq.map { chr => + val b = ByteBuffer.allocate(1) + b.put(chr.toByte).flip() + b + }) + LeafBuilder(connection).base(h) + }) + + for { + connection <- connectionR + _ <- stageResource(connection) + resp <- Resource.suspend(connection.runRequest(req, IO.never)) + } yield resp + } + + private def getSubmission( + req: Request[IO], + resp: String, + stage: Http1Connection[IO], + ): IO[(String, String)] = + for { + q <- Queue.unbounded[IO, Option[ByteBuffer]] + h = new QueueTestHead(q) + d <- Deferred[IO, Unit] + _ <- IO(LeafBuilder(stage).base(h)) + _ <- (d.get >> Stream + .emits(resp.toList) + .map { c => + val b = ByteBuffer.allocate(1) + b.put(c.toByte).flip() + b + } + .noneTerminate + .evalMap(q.offer) + .compile + .drain).start + req0 = req.pipeBodyThrough(_.onFinalizeWeak(d.complete(()).void)) + response <- stage.runRequest(req0, IO.never) + result <- response.use(_.as[String]) + _ <- IO(h.stageShutdown()) + buff <- IO.fromFuture(IO(h.result)) + _ <- d.get + request = new String(buff.array(), StandardCharsets.ISO_8859_1) + } yield (request, result) + + private def getSubmission( + req: Request[IO], + resp: String, + dispatcher: Dispatcher[IO], + userAgent: Option[`User-Agent`] = None, + ): IO[(String, String)] = { + val key = RequestKey.fromRequest(req) + mkConnection(key, dispatcher, userAgent).use(tail => getSubmission(req, resp, tail)) + } + + dispatcher.test("Run a basic request".flaky) { dispatcher => + getSubmission(FooRequest, resp, dispatcher).map { case (request, response) => + val statusLine = request.split("\r\n").apply(0) + assertEquals(statusLine, "GET / HTTP/1.1") + assertEquals(response, "done") + } + } + + dispatcher.test("Submit a request line with a query".flaky) { dispatcher => + val uri = "/huh?foo=bar" + val Right(parsed) = Uri.fromString("http://www.foo.test" + uri) + val req = Request[IO](uri = parsed) + + getSubmission(req, resp, dispatcher).map { case (request, response) => + val statusLine = request.split("\r\n").apply(0) + assertEquals(statusLine, "GET " + uri + " HTTP/1.1") + assertEquals(response, "done") + } + } + + fooConnection.test("Fail when attempting to get a second request with one in progress") { tail => + val (frag1, frag2) = resp.splitAt(resp.length - 1) + + val h = new SeqTestHead(List(mkBuffer(frag1), mkBuffer(frag2), mkBuffer(resp))) + LeafBuilder(tail).base(h) + + (for { + _ <- tail.runRequest(FooRequest, IO.never) // we remain in the body + _ <- tail.runRequest(FooRequest, IO.never) + } yield ()).intercept[Http1Connection.InProgressException.type] + } + + fooConnection.test("Alert the user if the body is to short") { tail => + val resp = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\ndone" + + val h = new SeqTestHead(List(mkBuffer(resp))) + LeafBuilder(tail).base(h) + + Resource + .suspend(tail.runRequest(FooRequest, IO.never)) + .use(_.body.compile.drain) + .intercept[InvalidBodyException] + } + + dispatcher.test("Interpret a lack of length with a EOF as a valid message") { dispatcher => + val resp = "HTTP/1.1 200 OK\r\n\r\ndone" + + getSubmission(FooRequest, resp, dispatcher).map(_._2).assertEquals("done") + } + + dispatcher.test("Utilize a provided Host header".flaky) { dispatcher => + val resp = "HTTP/1.1 200 OK\r\n\r\ndone" + + val req = FooRequest.withHeaders(headers.Host("bar.test")) + + getSubmission(req, resp, dispatcher).map { case (request, response) => + val requestLines = request.split("\r\n").toList + assert(requestLines.contains("Host: bar.test")) + assertEquals(response, "done") + } + } + + dispatcher.test("Insert a User-Agent header") { dispatcher => + val resp = "HTTP/1.1 200 OK\r\n\r\ndone" + + getSubmission(FooRequest, resp, dispatcher, DefaultUserAgent).map { case (request, response) => + val requestLines = request.split("\r\n").toList + assert(requestLines.contains(s"User-Agent: http4s-blaze/${BuildInfo.version}")) + assertEquals(response, "done") + } + } + + dispatcher.test("Use User-Agent header provided in Request".flaky) { dispatcher => + val resp = "HTTP/1.1 200 OK\r\n\r\ndone" + val req = FooRequest.withHeaders(`User-Agent`(ProductId("myagent"))) + + getSubmission(req, resp, dispatcher).map { case (request, response) => + val requestLines = request.split("\r\n").toList + assert(requestLines.contains("User-Agent: myagent")) + assertEquals(response, "done") + } + } + + fooConnection.test("Not add a User-Agent header when configured with None") { tail => + val resp = "HTTP/1.1 200 OK\r\n\r\ndone" + + getSubmission(FooRequest, resp, tail).map { case (request, response) => + val requestLines = request.split("\r\n").toList + assertEquals(requestLines.find(_.startsWith("User-Agent")), None) + assertEquals(response, "done") + } + } + + // TODO fs2 port - Currently is elevating the http version to 1.1 causing this test to fail + dispatcher.test("Allow an HTTP/1.0 request without a Host header".ignore) { dispatcher => + val resp = "HTTP/1.0 200 OK\r\n\r\ndone" + + val req = Request[IO](uri = www_foo_test, httpVersion = HttpVersion.`HTTP/1.0`) + + getSubmission(req, resp, dispatcher).map { case (request, response) => + assert(!request.contains("Host:")) + assertEquals(response, "done") + } + } + + dispatcher.test("Support flushing the prelude") { dispatcher => + val req = Request[IO](uri = www_foo_test, httpVersion = HttpVersion.`HTTP/1.0`) + /* + * We flush the prelude first to test connection liveness in pooled + * scenarios before we consume the body. Make sure we can handle + * it. Ensure that we still get a well-formed response. + */ + getSubmission(req, resp, dispatcher).map(_._2).assertEquals("done") + } + + fooConnection.test("Not expect body if request was a HEAD request") { tail => + val contentLength = 12345L + val resp = s"HTTP/1.1 200 OK\r\nContent-Length: $contentLength\r\n\r\n" + val headRequest = FooRequest.withMethod(Method.HEAD) + + val h = new SeqTestHead(List(mkBuffer(resp))) + LeafBuilder(tail).base(h) + + Resource.suspend(tail.runRequest(headRequest, IO.never)).use { response => + assertEquals(response.contentLength, Some(contentLength)) + + // body is empty due to it being HEAD request + response.body.compile.toVector.map(_.foldLeft(0L)((long, _) => long + 1L)).assertEquals(0L) + } + } + + { + val resp = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "3\r\n" + + "foo\r\n" + + "0\r\n" + + "Foo:Bar\r\n" + + "\r\n" + + val req = Request[IO](uri = www_foo_test, httpVersion = HttpVersion.`HTTP/1.1`) + + dispatcher.test("Support trailer headers") { dispatcher => + val hs: IO[Headers] = bracketResponse(req, resp, dispatcher).use { (response: Response[IO]) => + for { + _ <- response.as[String] + hs <- response.trailerHeaders + } yield hs + } + + hs.map(_.headers.mkString).assertEquals("Foo: Bar") + } + + dispatcher.test("Fail to get trailers before they are complete") { dispatcher => + val hs: IO[Headers] = bracketResponse(req, resp, dispatcher).use { (response: Response[IO]) => + for { + hs <- response.trailerHeaders + } yield hs + } + + hs.intercept[IllegalStateException] + } + } + + fooConnection.test("Close idle connection after server closes it") { tail => + val h = new TestHead("EofingTestHead") { + private val bodyIt = Seq(mkBuffer(resp)).iterator + + override def readRequest(size: Int): Future[ByteBuffer] = + synchronized { + if (!closed && bodyIt.hasNext) Future.successful(bodyIt.next()) + else Future.failed(EOF) + } + } + LeafBuilder(tail).base(h) + + for { + _ <- tail.runRequest(FooRequest, IO.never) // the first request succeeds + _ <- IO.sleep(200.millis) // then the server closes the connection + isClosed <- IO( + tail.isClosed + ) // and the client should recognize that the connection has been closed + } yield assert(isClosed) + } +} diff --git a/blaze-client/src/test/scala/org/http4s/blaze/client/PoolManagerSuite.scala b/blaze-client/src/test/scala/org/http4s/blaze/client/PoolManagerSuite.scala new file mode 100644 index 000000000..05f5e5398 --- /dev/null +++ b/blaze-client/src/test/scala/org/http4s/blaze/client/PoolManagerSuite.scala @@ -0,0 +1,265 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package client + +import cats.effect._ +import cats.effect.std._ +import cats.implicits._ +import com.comcast.ip4s._ +import fs2.Stream +import munit.CatsEffectSuite +import org.http4s.client.ConnectionFailure +import org.http4s.client.RequestKey +import org.http4s.syntax.AllSyntax + +import java.net.InetSocketAddress +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ + +class PoolManagerSuite extends CatsEffectSuite with AllSyntax { + private val key = + RequestKey(Uri.Scheme.http, Uri.Authority(host = Uri.Ipv4Address(ipv4"127.0.0.1"))) + private val otherKey = RequestKey(Uri.Scheme.http, Uri.Authority(host = Uri.RegName("localhost"))) + + class TestConnection extends Connection[IO] { + @volatile var isClosed = false + val isRecyclable = IO.pure(true) + def requestKey = key + def shutdown() = + isClosed = true + } + + private def mkPool( + maxTotal: Int, + maxWaitQueueLimit: Int = 10, + requestTimeout: Duration = Duration.Inf, + builder: ConnectionBuilder[IO, TestConnection] = _ => IO(new TestConnection()), + maxIdleDuration: Duration = Duration.Inf, + ) = + ConnectionManager.pool( + builder = builder, + maxTotal = maxTotal, + maxWaitQueueLimit = maxWaitQueueLimit, + maxConnectionsPerRequestKey = _ => 5, + responseHeaderTimeout = Duration.Inf, + requestTimeout = requestTimeout, + executionContext = ExecutionContext.Implicits.global, + maxIdleDuration = maxIdleDuration, + ) + + test("A pool manager should wait up to maxWaitQueueLimit") { + (for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 2) + _ <- pool.borrow(key) + _ <- + Stream(Stream.eval(pool.borrow(key))).repeat + .take(2) + .parJoinUnbounded + .compile + .toList + .attempt + } yield fail("Should have triggered timeout")).timeoutTo(2.seconds, IO.unit) + } + + test("A pool manager should throw at maxWaitQueueLimit") { + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 2) + _ <- pool.borrow(key) + att <- + Stream(Stream.eval(pool.borrow(key))).repeat + .take(3) + .parJoinUnbounded + .compile + .toList + .attempt + } yield assertEquals(att, Left(WaitQueueFullFailure())) + } + + test("A pool manager should wake up a waiting connection on release") { + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + fiber <- pool.borrow(key).start // Should be one waiting + _ <- pool.release(conn.connection) + _ <- fiber.join + } yield () + } + + // this is a regression test for https://github.com/http4s/http4s/issues/2962 + test( + "A pool manager should fail expired connections and then wake up a non-expired waiting connection on release" + ) { + val timeout = 50.milliseconds + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 3, requestTimeout = timeout) + conn <- pool.borrow(key) + waiting1 <- pool.borrow(key).void.start + waiting2 <- pool.borrow(key).void.start + _ <- IO.sleep(timeout + 20.milliseconds) + waiting3 <- pool.borrow(key).void.start + _ <- pool.release(conn.connection) + result1 <- waiting1.join + result2 <- waiting2.join + result3 <- waiting3.join + } yield { + assertEquals(result1, Outcome.errored[IO, Throwable, Unit](WaitQueueTimeoutException)) + assertEquals(result2, Outcome.errored[IO, Throwable, Unit](WaitQueueTimeoutException)) + assertEquals(result3, Outcome.succeeded[IO, Throwable, Unit](IO.unit)) + } + } + + test("A pool manager should wake up a waiting connection on invalidate") { + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + fiber <- pool.borrow(key).start // Should be one waiting + _ <- pool.invalidate(conn.connection) + _ <- fiber.join + } yield () + } + + test("A pool manager should close an idle connection when at max total connections") { + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + _ <- pool.release(conn.connection) + fiber <- pool.borrow(otherKey).start + _ <- fiber.join + } yield () + } + + test( + "A pool manager should wake up a waiting connection for a different request key on release" + ) { + for { + pool <- mkPool(maxTotal = 1, maxWaitQueueLimit = 1) + conn <- pool.borrow(key) + fiber <- pool.borrow(otherKey).start + _ <- pool.release(conn.connection) + _ <- fiber.join + } yield () + } + + test("A WaitQueueFullFailure should render message properly") { + assert((new WaitQueueFullFailure).toString.contains("Wait queue is full")) + } + + test("A pool manager should continue processing waitQueue after allocation failure".fail) { + for { + isEstablishingConnectionsPossible <- Ref[IO].of(true) + connectionFailure = new ConnectionFailure(key, new InetSocketAddress(1234), new Exception()) + pool <- mkPool( + maxTotal = 1, + maxWaitQueueLimit = 10, + builder = _ => + isEstablishingConnectionsPossible.get + .ifM(IO(new TestConnection()), IO.raiseError(connectionFailure)), + ) + conn1 <- pool.borrow(key) + conn2Fiber <- pool.borrow(key).start + conn3Fiber <- pool.borrow(key).start + _ <- IO.sleep(50.millis) // Give the fibers some time to end up in the waitQueue + _ <- isEstablishingConnectionsPossible.set(false) + _ <- pool.invalidate(conn1.connection) + _ <- conn2Fiber.join + .as(false) + .recover { case _: ConnectionFailure => true } + .assert + .timeout(200.millis) + _ <- conn3Fiber.join + .as(false) + .recover { case _: ConnectionFailure => true } + .assert + .timeout(200.millis) + // After failing to allocate conn2, the pool should attempt to allocate the conn3, + // but it doesn't so we hit the timeoeut. Without the timeout it would be a deadlock. + } yield () + } + + test( + "A pool manager should not deadlock after an attempt to create a connection is canceled".fail + ) { + for { + isEstablishingConnectionsHangs <- Ref[IO].of(true) + connectionAttemptsStarted <- Semaphore[IO](0L) + pool <- mkPool( + maxTotal = 1, + maxWaitQueueLimit = 10, + builder = _ => + connectionAttemptsStarted.release >> + isEstablishingConnectionsHangs.get.ifM(IO.never, IO(new TestConnection())), + ) + conn1Fiber <- pool.borrow(key).start + // wait for the first connection attempt to start before we cancel it + _ <- connectionAttemptsStarted.acquire + _ <- conn1Fiber.cancel + _ <- isEstablishingConnectionsHangs.set(false) + // The first connection attempt is canceled, so it should now be possible to acquire a new connection (but it's not because curAllocated==1==maxTotal) + _ <- pool.borrow(key).timeout(200.millis) + } yield () + } + + test("Should reissue recyclable connections with infinite maxIdleDuration") { + for { + pool <- mkPool( + maxTotal = 1, + maxIdleDuration = Duration.Inf, + ) + conn1 <- pool.borrow(key) + _ <- pool.release(conn1.connection) + conn2 <- pool.borrow(key) + } yield assertEquals(conn1.connection, conn2.connection) + } + + test("Should not reissue recyclable connections before maxIdleDuration") { + for { + pool <- mkPool( + maxTotal = 1, + maxIdleDuration = 365.days, + ) + conn1 <- pool.borrow(key) + _ <- pool.release(conn1.connection) + conn2 <- pool.borrow(key) + } yield assertEquals(conn1.connection, conn2.connection) + } + + test("Should not reissue recyclable connections beyond maxIdleDuration") { + for { + pool <- mkPool( + maxTotal = 1, + maxIdleDuration = Duration.Zero, + ) + conn1 <- pool.borrow(key) + _ <- pool.release(conn1.connection) + conn2 <- pool.borrow(key) + } yield assert(conn1.connection != conn2.connection) + } + + test("Should close connections borrowed beyond maxIdleDuration") { + for { + pool <- mkPool( + maxTotal = 1, + maxIdleDuration = Duration.Zero, + ) + conn1 <- pool.borrow(key) + _ <- pool.release(conn1.connection) + _ <- pool.borrow(key) + } yield assert(conn1.connection.isClosed) + } +} diff --git a/blaze-core/src/main/scala-2.12/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala b/blaze-core/src/main/scala-2.12/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala new file mode 100644 index 000000000..8d66e7a39 --- /dev/null +++ b/blaze-core/src/main/scala-2.12/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.util + +import org.http4s.blaze.util.Execution + +import scala.concurrent.ExecutionContext + +private[util] trait ParasiticExecutionContextCompat { + final def parasitic: ExecutionContext = Execution.trampoline +} diff --git a/blaze-core/src/main/scala-2.13/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala b/blaze-core/src/main/scala-2.13/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala new file mode 100644 index 000000000..4223bc6a6 --- /dev/null +++ b/blaze-core/src/main/scala-2.13/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.util + +import scala.concurrent.ExecutionContext + +private[util] trait ParasiticExecutionContextCompat { + final def parasitic: ExecutionContext = ExecutionContext.parasitic +} diff --git a/blaze-core/src/main/scala-3/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala b/blaze-core/src/main/scala-3/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala new file mode 100644 index 000000000..4223bc6a6 --- /dev/null +++ b/blaze-core/src/main/scala-3/org/http4s/blazecore/util/ParasiticExecutionContextCompat.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.util + +import scala.concurrent.ExecutionContext + +private[util] trait ParasiticExecutionContextCompat { + final def parasitic: ExecutionContext = ExecutionContext.parasitic +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/BlazeBackendBuilder.scala b/blaze-core/src/main/scala/org/http4s/blazecore/BlazeBackendBuilder.scala new file mode 100644 index 000000000..2ef31d111 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/BlazeBackendBuilder.scala @@ -0,0 +1,78 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import org.http4s.blaze.channel.ChannelOptions +import org.http4s.blaze.channel.OptionValue + +import java.net.SocketOption +import java.net.StandardSocketOptions + +private[http4s] trait BlazeBackendBuilder[B] { + type Self + + def channelOptions: ChannelOptions + + def channelOption[A](socketOption: SocketOption[A]): Option[A] = + channelOptions.options.collectFirst { + case OptionValue(key, value) if key == socketOption => + value.asInstanceOf[A] + } + def withChannelOptions(channelOptions: ChannelOptions): Self + def withChannelOption[A](key: SocketOption[A], value: A): Self = + withChannelOptions( + ChannelOptions(channelOptions.options.filterNot(_.key == key) :+ OptionValue(key, value)) + ) + def withDefaultChannelOption[A](key: SocketOption[A]): Self = + withChannelOptions(ChannelOptions(channelOptions.options.filterNot(_.key == key))) + + def socketSendBufferSize: Option[Int] = + channelOption(StandardSocketOptions.SO_SNDBUF).map(Int.unbox) + def withSocketSendBufferSize(socketSendBufferSize: Int): Self = + withChannelOption(StandardSocketOptions.SO_SNDBUF, Int.box(socketSendBufferSize)) + def withDefaultSocketSendBufferSize: Self = + withDefaultChannelOption(StandardSocketOptions.SO_SNDBUF) + + def socketReceiveBufferSize: Option[Int] = + channelOption(StandardSocketOptions.SO_RCVBUF).map(Int.unbox) + def withSocketReceiveBufferSize(socketReceiveBufferSize: Int): Self = + withChannelOption(StandardSocketOptions.SO_RCVBUF, Int.box(socketReceiveBufferSize)) + def withDefaultSocketReceiveBufferSize: Self = + withDefaultChannelOption(StandardSocketOptions.SO_RCVBUF) + + def socketKeepAlive: Option[Boolean] = + channelOption(StandardSocketOptions.SO_KEEPALIVE).map(Boolean.unbox) + def withSocketKeepAlive(socketKeepAlive: Boolean): Self = + withChannelOption(StandardSocketOptions.SO_KEEPALIVE, Boolean.box(socketKeepAlive)) + def withDefaultSocketKeepAlive: Self = + withDefaultChannelOption(StandardSocketOptions.SO_KEEPALIVE) + + def socketReuseAddress: Option[Boolean] = + channelOption(StandardSocketOptions.SO_REUSEADDR).map(Boolean.unbox) + def withSocketReuseAddress(socketReuseAddress: Boolean): Self = + withChannelOption(StandardSocketOptions.SO_REUSEADDR, Boolean.box(socketReuseAddress)) + def withDefaultSocketReuseAddress: Self = + withDefaultChannelOption(StandardSocketOptions.SO_REUSEADDR) + + def tcpNoDelay: Option[Boolean] = + channelOption(StandardSocketOptions.TCP_NODELAY).map(Boolean.unbox) + def withTcpNoDelay(tcpNoDelay: Boolean): Self = + withChannelOption(StandardSocketOptions.TCP_NODELAY, Boolean.box(tcpNoDelay)) + def withDefaultTcpNoDelay: Self = + withDefaultChannelOption(StandardSocketOptions.TCP_NODELAY) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/ExecutionContextConfig.scala b/blaze-core/src/main/scala/org/http4s/blazecore/ExecutionContextConfig.scala new file mode 100644 index 000000000..cb4ea7f42 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/ExecutionContextConfig.scala @@ -0,0 +1,35 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore + +import cats.effect.Async +import cats.syntax.all._ + +import scala.concurrent.ExecutionContext + +private[http4s] sealed trait ExecutionContextConfig extends Product with Serializable { + def getExecutionContext[F[_]: Async]: F[ExecutionContext] = this match { + case ExecutionContextConfig.DefaultContext => Async[F].executionContext + case ExecutionContextConfig.ExplicitContext(ec) => ec.pure[F] + } +} + +private[http4s] object ExecutionContextConfig { + case object DefaultContext extends ExecutionContextConfig + final case class ExplicitContext(executionContext: ExecutionContext) + extends ExecutionContextConfig +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/Http1Stage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/Http1Stage.scala new file mode 100644 index 000000000..4452dab1a --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/Http1Stage.scala @@ -0,0 +1,332 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import cats.effect.Async +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2.Stream._ +import fs2._ +import org.http4s.blaze.http.parser.BaseExceptions.ParserException +import org.http4s.blaze.pipeline.Command +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.util.BufferTools +import org.http4s.blaze.util.BufferTools.emptyBuffer +import org.http4s.blazecore.util._ +import org.http4s.headers._ +import org.http4s.syntax.header._ +import org.http4s.util.Renderer +import org.http4s.util.StringWriter +import org.http4s.util.Writer + +import java.nio.ByteBuffer +import java.time.Instant +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.util.Failure +import scala.util.Success + +/** Utility bits for dealing with the HTTP 1.x protocol */ +private[http4s] trait Http1Stage[F[_]] { self: TailStage[ByteBuffer] => + + /** ExecutionContext to be used for all Future continuations + * '''WARNING:''' The ExecutionContext should trampoline or risk possibly unhandled stack overflows + */ + implicit protected def executionContext: ExecutionContext + + implicit protected def F: Async[F] + + implicit protected def dispatcher: Dispatcher[F] + + protected def chunkBufferMaxSize: Int + + protected def doParseContent(buffer: ByteBuffer): Option[ByteBuffer] + + protected def contentComplete(): Boolean + + /** Check Connection header and add applicable headers to response */ + protected final def checkCloseConnection(conn: Connection, rr: StringWriter): Boolean = + if (conn.hasKeepAlive) { // connection, look to the request + logger.trace("Found Keep-Alive header") + false + } else if (conn.hasClose) { + logger.trace("Found Connection:Close header") + rr << "Connection:close\r\n" + true + } else { + logger.info( + s"Unknown connection header: '${conn.value}'. Closing connection upon completion." + ) + rr << "Connection:close\r\n" + true + } + + /** Get the proper body encoder based on the request */ + protected final def getEncoder( + req: Request[F], + rr: StringWriter, + minor: Int, + closeOnFinish: Boolean, + ): Http1Writer[F] = { + val headers = req.headers + getEncoder( + headers.get[Connection], + headers.get[`Transfer-Encoding`], + headers.get[`Content-Length`], + req.trailerHeaders, + rr, + minor, + closeOnFinish, + Http1Stage.omitEmptyContentLength(req), + ) + } + + /** Get the proper body encoder based on the request, + * adding the appropriate Connection and Transfer-Encoding headers along the way + */ + protected final def getEncoder( + connectionHeader: Option[Connection], + bodyEncoding: Option[`Transfer-Encoding`], + lengthHeader: Option[`Content-Length`], + trailer: F[Headers], + rr: StringWriter, + minor: Int, + closeOnFinish: Boolean, + omitEmptyContentLength: Boolean, + ): Http1Writer[F] = + lengthHeader match { + case Some(h) if bodyEncoding.forall(!_.hasChunked) || minor == 0 => + // HTTP 1.1: we have a length and no chunked encoding + // HTTP 1.0: we have a length + + bodyEncoding.foreach(enc => + logger.warn( + s"Unsupported transfer encoding: '${enc.value}' for HTTP 1.$minor. Stripping header." + ) + ) + + logger.trace("Using static encoder") + + rr << h << "\r\n" // write Content-Length + + // add KeepAlive to Http 1.0 responses if the header isn't already present + rr << (if (!closeOnFinish && minor == 0 && connectionHeader.isEmpty) + "Connection: keep-alive\r\n\r\n" + else "\r\n") + + new IdentityWriter[F](h.length, this) + + case _ => // No Length designated for body or Transfer-Encoding included for HTTP 1.1 + if (minor == 0) // we are replying to a HTTP 1.0 request see if the length is reasonable + if (closeOnFinish) { // HTTP 1.0 uses a static encoder + logger.trace("Using static encoder") + rr << "\r\n" + new IdentityWriter[F](-1, this) + } else { // HTTP 1.0, but request was Keep-Alive. + logger.trace("Using static encoder without length") + new CachingStaticWriter[F]( + this + ) // will cache for a bit, then signal close if the body is long + } + else + bodyEncoding match { // HTTP >= 1.1 request without length and/or with chunked encoder + case Some(enc) => // Signaling chunked means flush every chunk + if (!enc.hasChunked) + logger.warn( + s"Unsupported transfer encoding: '${enc.value}' for HTTP 1.$minor. Stripping header." + ) + + if (lengthHeader.isDefined) + logger.warn( + s"Both Content-Length and Transfer-Encoding headers defined. Stripping Content-Length." + ) + + new FlushingChunkWriter(this, trailer) + + case None => // use a cached chunk encoder for HTTP/1.1 without length of transfer encoding + logger.trace("Using Caching Chunk Encoder") + new CachingChunkWriter(this, trailer, chunkBufferMaxSize, omitEmptyContentLength) + } + } + + /** Makes a [[EntityBody]] and a function used to drain the line if terminated early. + * + * @param buffer starting `ByteBuffer` to use in parsing. + * @param eofCondition If the other end hangs up, this is the condition used in the stream for termination. + * The desired result will differ between Client and Server as the former can interpret + * and `Command.EOF` as the end of the body while a server cannot. + */ + protected final def collectBodyFromParser( + buffer: ByteBuffer, + eofCondition: () => Either[Throwable, Option[Chunk[Byte]]], + ): (EntityBody[F], () => Future[ByteBuffer]) = + if (contentComplete()) + if (buffer.remaining() == 0) Http1Stage.CachedEmptyBody + else (EmptyBody, () => Future.successful(buffer)) + // try parsing the existing buffer: many requests will come as a single chunk + else if (buffer.hasRemaining) doParseContent(buffer) match { + case Some(buff) if contentComplete() => + Stream.chunk(Chunk.byteBuffer(buff)) -> Http1Stage + .futureBufferThunk(buffer) + + case Some(buff) => + val (rst, end) = streamingBody(buffer, eofCondition) + (Stream.chunk(Chunk.byteBuffer(buff)) ++ rst, end) + + case None if contentComplete() => + if (buffer.hasRemaining) EmptyBody -> Http1Stage.futureBufferThunk(buffer) + else Http1Stage.CachedEmptyBody + + case None => streamingBody(buffer, eofCondition) + } + // we are not finished and need more data. + else streamingBody(buffer, eofCondition) + + // Streams the body off the wire + private def streamingBody( + buffer: ByteBuffer, + eofCondition: () => Either[Throwable, Option[Chunk[Byte]]], + ): (EntityBody[F], () => Future[ByteBuffer]) = { + @volatile var currentBuffer = buffer + + // TODO: we need to work trailers into here somehow + val t = F.async_[Option[Chunk[Byte]]] { cb => + if (!contentComplete()) { + def go(): Unit = + try { + val parseResult = doParseContent(currentBuffer) + logger.debug(s"Parse result: $parseResult, content complete: ${contentComplete()}") + parseResult match { + case Some(result) => + cb(Either.right(Chunk.byteBuffer(result).some)) + + case None if contentComplete() => + cb(End) + + case None => + channelRead().onComplete { + case Success(b) => + currentBuffer = BufferTools.concatBuffers(currentBuffer, b) + go() + + case Failure(Command.EOF) => + cb(eofCondition()) + + case Failure(t) => + logger.error(t)("Unexpected error reading body.") + cb(Either.left(t)) + } + } + } catch { + case t: ParserException => + fatalError(t, "Error parsing request body") + cb(Either.left(InvalidBodyException(t.getMessage()))) + + case t: Throwable => + fatalError(t, "Error collecting body") + cb(Either.left(t)) + } + go() + } else cb(End) + } + + (repeatEval(t).unNoneTerminate.flatMap(chunk(_)), () => drainBody(currentBuffer)) + } + + /** Called when a fatal error has occurred + * The method logs an error and shuts down the stage, sending the error outbound + * @param t + * @param msg + */ + protected def fatalError(t: Throwable, msg: String): Unit = { + logger.error(t)(s"Fatal Error: $msg") + stageShutdown() + closePipeline(Some(t)) + } + + /** Cleans out any remaining body from the parser */ + protected final def drainBody(buffer: ByteBuffer): Future[ByteBuffer] = { + logger.trace(s"Draining body: $buffer") + + while (!contentComplete() && doParseContent(buffer).nonEmpty) { /* NOOP */ } + + if (contentComplete()) Future.successful(buffer) + else { + // Send the EOF to trigger a connection shutdown + logger.info(s"HTTP body not read to completion. Dropping connection.") + Future.failed(Command.EOF) + } + } +} + +object Http1Stage { + private val CachedEmptyBufferThunk = { + val b = Future.successful(emptyBuffer) + () => b + } + + private val CachedEmptyBody = EmptyBody -> CachedEmptyBufferThunk + + // Building the current Date header value each time is expensive, so we cache it for the current second + private var currentEpoch: Long = _ + private var cachedString: String = _ + + private val NoPayloadMethods: Set[Method] = + Set(Method.GET, Method.DELETE, Method.CONNECT, Method.TRACE) + + private def currentDate: String = { + val now = Instant.now() + val epochSecond = now.getEpochSecond + if (epochSecond != currentEpoch) { + currentEpoch = epochSecond + cachedString = Renderer.renderString(now) + } + cachedString + } + + private def futureBufferThunk(buffer: ByteBuffer): () => Future[ByteBuffer] = + if (buffer.hasRemaining) { () => + Future.successful(buffer) + } else CachedEmptyBufferThunk + + /** Encodes the headers into the Writer. Does not encode + * `Transfer-Encoding` or `Content-Length` headers, which are left + * for the body encoder. Does not encode headers with invalid + * names. Adds `Date` header if one is missing and this is a server + * response. + * + * Note: this method is very niche but useful for both server and client. + */ + def encodeHeaders(headers: Iterable[Header.Raw], rr: Writer, isServer: Boolean): Unit = { + var dateEncoded = false + val dateName = Header[Date].name + headers.foreach { h => + if (h.name != `Transfer-Encoding`.name && h.name != `Content-Length`.name && h.isNameValid) { + if (isServer && h.name == dateName) dateEncoded = true + rr << h << "\r\n" + } + } + + if (isServer && !dateEncoded) + rr << dateName << ": " << currentDate << "\r\n" + () + } + + private def omitEmptyContentLength[F[_]](req: Request[F]) = + NoPayloadMethods.contains(req.method) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/IdleTimeoutStage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/IdleTimeoutStage.scala new file mode 100644 index 000000000..9c2124aa3 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/IdleTimeoutStage.scala @@ -0,0 +1,165 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import org.http4s.blaze.pipeline.MidStage +import org.http4s.blaze.util.Cancelable +import org.http4s.blaze.util.Execution +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.IdleTimeoutStage.Disabled +import org.http4s.blazecore.IdleTimeoutStage.Enabled +import org.http4s.blazecore.IdleTimeoutStage.ShutDown +import org.http4s.blazecore.IdleTimeoutStage.State + +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.tailrec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration +import scala.util.control.NonFatal + +private[http4s] final class IdleTimeoutStage[A]( + timeout: FiniteDuration, + exec: TickWheelExecutor, + ec: ExecutionContext, +) extends MidStage[A, A] { stage => + + private val timeoutState = new AtomicReference[State](Disabled) + + override def name: String = "IdleTimeoutStage" + + override def readRequest(size: Int): Future[A] = + channelRead(size).andThen { case _ => resetTimeout() }(Execution.directec) + + override def writeRequest(data: A): Future[Unit] = { + resetTimeout() + channelWrite(data) + } + + override def writeRequest(data: collection.Seq[A]): Future[Unit] = { + resetTimeout() + channelWrite(data) + } + + override protected def stageShutdown(): Unit = { + logger.debug(s"Shutting down idle timeout stage") + + @tailrec def go(): Unit = + timeoutState.get() match { + case old @ IdleTimeoutStage.Enabled(_, cancel) => + if (timeoutState.compareAndSet(old, ShutDown)) cancel.cancel() + else go() + case old => + if (!timeoutState.compareAndSet(old, ShutDown)) go() + } + + go() + + super.stageShutdown() + } + + def init(cb: Callback[TimeoutException]): Unit = setTimeout(cb) + + def setTimeout(cb: Callback[TimeoutException]): Unit = { + logger.debug(s"Starting idle timeout with timeout of ${timeout.toMillis} ms") + + val timeoutTask = new Runnable { + override def run(): Unit = { + val t = new TimeoutException(s"Idle timeout after ${timeout.toMillis} ms.") + logger.debug(t.getMessage) + cb(Right(t)) + } + } + + @tailrec def go(): Unit = + timeoutState.get() match { + case Disabled => + tryScheduling(timeoutTask) match { + case Some(newCancel) => + if (timeoutState.compareAndSet(Disabled, Enabled(timeoutTask, newCancel))) () + else { + newCancel.cancel() + go() + } + case None => () + } + case old @ Enabled(_, oldCancel) => + tryScheduling(timeoutTask) match { + case Some(newCancel) => + if (timeoutState.compareAndSet(old, Enabled(timeoutTask, newCancel))) + oldCancel.cancel() + else { + newCancel.cancel() + go() + } + case None => () + } + case _ => () + } + + go() + } + + @tailrec private def resetTimeout(): Unit = + timeoutState.get() match { + case old @ Enabled(timeoutTask, oldCancel) => + tryScheduling(timeoutTask) match { + case Some(newCancel) => + if (timeoutState.compareAndSet(old, Enabled(timeoutTask, newCancel))) oldCancel.cancel() + else { + newCancel.cancel() + resetTimeout() + } + case None => () + } + case _ => () + } + + @tailrec def cancelTimeout(): Unit = + timeoutState.get() match { + case old @ Enabled(_, cancel) => + if (timeoutState.compareAndSet(old, Disabled)) cancel.cancel() + else cancelTimeout() + case _ => () + } + + def tryScheduling(timeoutTask: Runnable): Option[Cancelable] = + if (exec.isAlive) { + try Some(exec.schedule(timeoutTask, ec, timeout)) + catch { + case TickWheelExecutor.AlreadyShutdownException => + logger.warn(s"Resetting timeout after tickwheelexecutor is shutdown") + None + case NonFatal(e) => throw e + } + } else { + None + } +} + +object IdleTimeoutStage { + + sealed trait State + case object Disabled extends State + // scalafix:off Http4sGeneralLinters; bincompat until 1.0 + case class Enabled(timeoutTask: Runnable, cancel: Cancelable) extends State + // scalafix:on + case object ShutDown extends State + +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/ResponseHeaderTimeoutStage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/ResponseHeaderTimeoutStage.scala new file mode 100644 index 000000000..91a633ea2 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/ResponseHeaderTimeoutStage.scala @@ -0,0 +1,98 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import org.http4s.blaze.pipeline.MidStage +import org.http4s.blaze.util.Cancelable +import org.http4s.blaze.util.TickWheelExecutor + +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicReference +import scala.annotation.tailrec +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration + +private[http4s] final class ResponseHeaderTimeoutStage[A]( + timeout: FiniteDuration, + exec: TickWheelExecutor, + ec: ExecutionContext, +) extends MidStage[A, A] { stage => + @volatile private[this] var cb: Callback[TimeoutException] = null + + private val timeoutState = new AtomicReference[Cancelable](NoOpCancelable) + + override def name: String = "ResponseHeaderTimeoutStage" + + private val killSwitch = new Runnable { + override def run(): Unit = { + val t = new TimeoutException(s"Response header timeout after ${timeout.toMillis} ms.") + logger.debug(t.getMessage) + cb(Left(t)) + removeStage() + } + } + + override def readRequest(size: Int): Future[A] = + channelRead(size) + + override def writeRequest(data: A): Future[Unit] = { + setTimeout() + channelWrite(data) + } + + override def writeRequest(data: collection.Seq[A]): Future[Unit] = { + setTimeout() + channelWrite(data) + } + + override protected def stageShutdown(): Unit = { + cancelTimeout() + logger.debug(s"Shutting down response header timeout stage") + super.stageShutdown() + } + + override def stageStartup(): Unit = { + super.stageStartup() + logger.debug(s"Starting response header timeout stage with timeout of ${timeout}") + } + + def init(cb: Callback[TimeoutException]): Unit = { + this.cb = cb + stageStartup() + } + + private def setTimeout(): Unit = { + @tailrec + def go(): Unit = { + val prev = timeoutState.get() + if (prev == NoOpCancelable) { + val next = exec.schedule(killSwitch, ec, timeout) + if (!timeoutState.compareAndSet(prev, next)) { + next.cancel() + go() + } else + prev.cancel() + } + } + go() + } + + private def cancelTimeout(): Unit = + timeoutState.getAndSet(NoOpCancelable).cancel() +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/package.scala b/blaze-core/src/main/scala/org/http4s/blazecore/package.scala new file mode 100644 index 000000000..7990e5d76 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/package.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s + +import cats.effect.Resource +import cats.effect.Sync +import org.http4s.blaze.util.Cancelable +import org.http4s.blaze.util.TickWheelExecutor + +package object blazecore { + + private[http4s] def tickWheelResource[F[_]](implicit F: Sync[F]): Resource[F, TickWheelExecutor] = + Resource(F.delay { + val s = new TickWheelExecutor() + (s, F.delay(s.shutdown())) + }) + + private[blazecore] val NoOpCancelable = new Cancelable { + def cancel() = () + override def toString = "no op cancelable" + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/BodylessWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/BodylessWriter.scala new file mode 100644 index 000000000..03c7d0ed0 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/BodylessWriter.scala @@ -0,0 +1,53 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import cats.syntax.all._ +import fs2._ +import org.http4s.blaze.pipeline._ +import org.http4s.util.StringWriter + +import java.nio.ByteBuffer +import scala.concurrent._ + +/** Discards the body, killing it so as to clean up resources + * + * @param pipe the blaze `TailStage`, which takes ByteBuffers which will send the data downstream + */ +private[http4s] class BodylessWriter[F[_]](pipe: TailStage[ByteBuffer], close: Boolean)(implicit + protected val F: Async[F] +) extends Http1Writer[F] { + def writeHeaders(headerWriter: StringWriter): Future[Unit] = + pipe.channelWrite(Http1Writer.headersToByteBuffer(headerWriter.result)) + + /** Doesn't write the entity body, just the headers. Kills the stream, if an error if necessary + * + * @param p an entity body that will be killed + * @return the F which, when run, will send the headers and kill the entity body + */ + override def writeEntityBody(p: EntityBody[F]): F[Boolean] = + p.compile.drain.as(close) + + override protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = + Future.successful(close) + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + FutureUnit +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingChunkWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingChunkWriter.scala new file mode 100644 index 000000000..3770e0626 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingChunkWriter.scala @@ -0,0 +1,134 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import cats.effect.std.Dispatcher +import fs2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets.ISO_8859_1 +import scala.collection.mutable.Buffer +import scala.concurrent._ + +private[http4s] class CachingChunkWriter[F[_]]( + pipe: TailStage[ByteBuffer], + trailer: F[Headers], + bufferMaxSize: Int, + omitEmptyContentLength: Boolean, +)(implicit + protected val F: Async[F], + private val ec: ExecutionContext, + protected val dispatcher: Dispatcher[F], +) extends Http1Writer[F] { + import ChunkWriter._ + + private[this] var pendingHeaders: StringWriter = _ + private[this] val bodyBuffer: Buffer[Chunk[Byte]] = Buffer() + private[this] var size: Int = 0 + + override def writeHeaders(headerWriter: StringWriter): Future[Unit] = { + pendingHeaders = headerWriter + FutureUnit + } + + private def addChunk(chunk: Chunk[Byte]): Unit = + if (chunk.nonEmpty) { + bodyBuffer += chunk + size += chunk.size + } + + private def toChunkAndClear: Chunk[Byte] = { + val chunk = if (size == 0) { + Chunk.empty + } else if (bodyBuffer.size == 1) { + bodyBuffer.head + } else { + Chunk.concat(bodyBuffer) + } + bodyBuffer.clear() + size = 0 + chunk + } + + override protected def exceptionFlush(): Future[Unit] = + if (size > 0) { + val c = toChunkAndClear + pipe.channelWrite(encodeChunk(c, Nil)) + } else { + FutureUnit + } + + def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = { + addChunk(chunk) + val c = toChunkAndClear + doWriteEnd(c) + } + + private def doWriteEnd(chunk: Chunk[Byte]): Future[Boolean] = { + val f = + if (pendingHeaders != null) { // This is the first write, so we can add a body length instead of chunking + val h = pendingHeaders + pendingHeaders = null + + if (!chunk.isEmpty) { + val body = chunk.toByteBuffer + h << s"Content-Length: ${body.remaining()}\r\n\r\n" + + // Trailers are optional, so dropping because we have no body. + val hbuff = ByteBuffer.wrap(h.result.getBytes(ISO_8859_1)) + pipe.channelWrite(hbuff :: body :: Nil) + } else { + if (!omitEmptyContentLength) + h << s"Content-Length: 0\r\n" + h << "\r\n" + val hbuff = ByteBuffer.wrap(h.result.getBytes(ISO_8859_1)) + pipe.channelWrite(hbuff) + } + } else if (!chunk.isEmpty) { + writeBodyChunk(chunk, true).flatMap { _ => + writeTrailer(pipe, trailer) + } + } else { + writeTrailer(pipe, trailer) + } + + f.map(Function.const(false)) + } + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = { + addChunk(chunk) + if (size >= bufferMaxSize || flush) { // time to flush + val c = toChunkAndClear + pipe.channelWrite(encodeChunk(c, Nil)) + } else FutureUnit // Pretend to be done. + } + + private def encodeChunk(chunk: Chunk[Byte], last: List[ByteBuffer]): List[ByteBuffer] = { + val list = ChunkWriter.encodeChunk(chunk, last) + if (pendingHeaders != null) { + pendingHeaders << TransferEncodingChunkedString + val b = ByteBuffer.wrap(pendingHeaders.result.getBytes(ISO_8859_1)) + pendingHeaders = null + b :: list + } else list + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingStaticWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingStaticWriter.scala new file mode 100644 index 000000000..f3fd92fee --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/CachingStaticWriter.scala @@ -0,0 +1,95 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import fs2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter + +import java.nio.ByteBuffer +import scala.collection.mutable.Buffer +import scala.concurrent.Future + +private[http4s] class CachingStaticWriter[F[_]]( + out: TailStage[ByteBuffer], + bufferSize: Int = 8 * 1024, +)(implicit protected val F: Async[F]) + extends Http1Writer[F] { + @volatile + private var _forceClose = false + private val bodyBuffer: Buffer[Chunk[Byte]] = Buffer() + private var writer: StringWriter = null + private var innerWriter: InnerWriter = _ + + def writeHeaders(headerWriter: StringWriter): Future[Unit] = { + this.writer = headerWriter + FutureUnit + } + + private def addChunk(chunk: Chunk[Byte]): Unit = { + bodyBuffer += chunk + () + } + + private def toChunk: Chunk[Byte] = Chunk.concat(bodyBuffer) + + private def clear(): Unit = bodyBuffer.clear() + + override protected def exceptionFlush(): Future[Unit] = { + val c = toChunk + clear() + + if (innerWriter == null) { // We haven't written anything yet + writer << "\r\n" + new InnerWriter().writeBodyChunk(c, flush = true) + } else writeBodyChunk(c, flush = true) // we are already proceeding + } + + override protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = + if (innerWriter != null) innerWriter.writeEnd(chunk) + else { // We are finished! Write the length and the keep alive + addChunk(chunk) + val c = toChunk + clear() + writer << "Content-Length: " << c.size << "\r\nConnection: keep-alive\r\n\r\n" + + new InnerWriter().writeEnd(c).map(_ || _forceClose)(parasitic) + } + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + if (innerWriter != null) innerWriter.writeBodyChunk(chunk, flush) + else { + addChunk(chunk) + val c = toChunk + if (flush || c.size >= bufferSize) { // time to just abort and stream it + _forceClose = true + writer << "\r\n" + innerWriter = new InnerWriter + innerWriter.writeBodyChunk(chunk, flush) + } else FutureUnit + } + + // Make the write stuff public + private class InnerWriter extends IdentityWriter(-1, out) { + override def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = super.writeEnd(chunk) + override def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + super.writeBodyChunk(chunk, flush) + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/ChunkWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/ChunkWriter.scala new file mode 100644 index 000000000..698c2806a --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/ChunkWriter.scala @@ -0,0 +1,79 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect.Async +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets.ISO_8859_1 +import scala.concurrent._ + +private[util] object ChunkWriter { + val CRLFBytes: Array[Byte] = "\r\n".getBytes(ISO_8859_1) + private[this] val CRLFBuffer = ByteBuffer.wrap(CRLFBytes).asReadOnlyBuffer() + def CRLF: ByteBuffer = CRLFBuffer.duplicate() + + private[this] val chunkEndBuffer = + ByteBuffer.wrap("0\r\n\r\n".getBytes(ISO_8859_1)).asReadOnlyBuffer() + def ChunkEndBuffer: ByteBuffer = chunkEndBuffer.duplicate() + + val TransferEncodingChunkedString = "Transfer-Encoding: chunked\r\n\r\n" + private[this] val TransferEncodingChunkedBytes = + "Transfer-Encoding: chunked\r\n\r\n".getBytes(ISO_8859_1) + private[this] val transferEncodingChunkedBuffer = + ByteBuffer.wrap(TransferEncodingChunkedBytes).asReadOnlyBuffer + def TransferEncodingChunked: ByteBuffer = transferEncodingChunkedBuffer.duplicate() + + def writeTrailer[F[_]](pipe: TailStage[ByteBuffer], trailer: F[Headers])(implicit + F: Async[F], + ec: ExecutionContext, + dispatcher: Dispatcher[F], + ): Future[Boolean] = { + val f = trailer.map { trailerHeaders => + if (!trailerHeaders.isEmpty) { + val rr = new StringWriter(256) + rr << "0\r\n" // Last chunk + trailerHeaders.foreach { h => + rr << h << "\r\n"; () + } // trailers + rr << "\r\n" // end of chunks + ByteBuffer.wrap(rr.result.getBytes(ISO_8859_1)) + } else ChunkEndBuffer + } + for { + buffer <- dispatcher.unsafeToFuture(f) + _ <- pipe.channelWrite(buffer) + } yield false + } + + def writeLength(length: Long): ByteBuffer = { + val bytes = length.toHexString.getBytes(ISO_8859_1) + val b = ByteBuffer.allocate(bytes.length + 2) + b.put(bytes).put(CRLFBytes).flip() + b + } + + def encodeChunk(chunk: Chunk[Byte], last: List[ByteBuffer]): List[ByteBuffer] = + writeLength(chunk.size.toLong) :: chunk.toByteBuffer :: CRLF :: last +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/EntityBodyWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/EntityBodyWriter.scala new file mode 100644 index 000000000..fa7f5f72e --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/EntityBodyWriter.scala @@ -0,0 +1,88 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import cats.syntax.all._ +import fs2._ + +import scala.concurrent._ + +private[http4s] trait EntityBodyWriter[F[_]] { + implicit protected def F: Async[F] + + protected val wroteHeader: Promise[Unit] = Promise[Unit]() + + /** Write a Chunk to the wire. + * + * @param chunk BodyChunk to write to wire + * @return a future letting you know when its safe to continue + */ + protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] + + /** Write the ending chunk and, in chunked encoding, a trailer to the + * wire. + * + * @param chunk BodyChunk to write to wire + * @return a future letting you know when its safe to continue (if `false`) or + * to close the connection (if `true`) + */ + protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] + + /** Called in the event of an Await failure to alert the pipeline to cleanup */ + protected def exceptionFlush(): Future[Unit] = FutureUnit + + /** Creates an effect that writes the contents of the EntityBody to the output. + * The writeBodyEnd triggers if there are no exceptions, and the result will + * be the result of the writeEnd call. + * + * @param p EntityBody to write out + * @return the Task which when run will unwind the Process + */ + def writeEntityBody(p: EntityBody[F]): F[Boolean] = { + val writeBody: F[Unit] = writePipe(p).compile.drain + val writeBodyEnd: F[Boolean] = fromFutureNoShift(F.delay(writeEnd(Chunk.empty))) + writeBody *> writeBodyEnd + } + + /** Writes each of the body chunks, if the write fails it returns + * the failed future which throws an error. + * If it errors the error stream becomes the stream, which performs an + * exception flush and then the stream fails. + */ + private def writePipe(s: Stream[F, Byte]): Stream[F, INothing] = { + def writeChunk(chunk: Chunk[Byte]): F[Unit] = + fromFutureNoShift(F.delay(writeBodyChunk(chunk, flush = false))) + + val writeStream: Stream[F, INothing] = + s.repeatPull { + _.uncons.flatMap { + case None => Pull.pure(None) + case Some((hd, tl)) => Pull.eval(writeChunk(hd)).as(Some(tl)) + } + } + + val errorStream: Throwable => Stream[F, INothing] = e => + Stream + .eval(fromFutureNoShift(F.delay(exceptionFlush()))) + .flatMap(_ => Stream.raiseError[F](e)) + writeStream.handleErrorWith(errorStream) + } + +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/FlushingChunkWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/FlushingChunkWriter.scala new file mode 100644 index 000000000..856e64f50 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/FlushingChunkWriter.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect.Async +import cats.effect.std.Dispatcher +import fs2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter + +import java.nio.ByteBuffer +import scala.concurrent._ + +private[http4s] class FlushingChunkWriter[F[_]](pipe: TailStage[ByteBuffer], trailer: F[Headers])( + implicit + protected val F: Async[F], + private val ec: ExecutionContext, + protected val dispatcher: Dispatcher[F], +) extends Http1Writer[F] { + import ChunkWriter._ + + protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + if (chunk.isEmpty) FutureUnit + else pipe.channelWrite(encodeChunk(chunk, Nil)) + + protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = { + if (!chunk.isEmpty) writeBodyChunk(chunk, true).flatMap { _ => + writeTrailer(pipe, trailer) + } + else writeTrailer(pipe, trailer) + }.map(_ => false)(parasitic) + + override def writeHeaders(headerWriter: StringWriter): Future[Unit] = + // It may be a while before we get another chunk, so we flush now + pipe.channelWrite( + List(Http1Writer.headersToByteBuffer(headerWriter.result), TransferEncodingChunked) + ) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/Http1Writer.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/Http1Writer.scala new file mode 100644 index 000000000..f69c02812 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/Http1Writer.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect.kernel.Outcome +import cats.effect.syntax.monadCancel._ +import cats.syntax.all._ +import org.http4s.util.StringWriter +import org.log4s.getLogger + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.concurrent._ + +private[http4s] trait Http1Writer[F[_]] extends EntityBodyWriter[F] { + final def write(headerWriter: StringWriter, body: EntityBody[F]): F[Boolean] = + fromFutureNoShift(F.delay(writeHeaders(headerWriter))) + .guaranteeCase { + case Outcome.Succeeded(_) => + F.unit + + case Outcome.Errored(_) | Outcome.Canceled() => + body.compile.drain.handleError { t2 => + Http1Writer.logger.error(t2)("Error draining body") + } + } >> writeEntityBody(body) + + /* Writes the header. It is up to the writer whether to flush immediately or to + * buffer the header with a subsequent chunk. */ + def writeHeaders(headerWriter: StringWriter): Future[Unit] +} + +private[util] object Http1Writer { + private val logger = getLogger + + def headersToByteBuffer(headers: String): ByteBuffer = + ByteBuffer.wrap(headers.getBytes(StandardCharsets.ISO_8859_1)) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/Http2Writer.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/Http2Writer.scala new file mode 100644 index 000000000..64377bbee --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/Http2Writer.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import fs2._ +import org.http4s.blaze.http.Headers +import org.http4s.blaze.http.http2.DataFrame +import org.http4s.blaze.http.http2.HeadersFrame +import org.http4s.blaze.http.http2.Priority +import org.http4s.blaze.http.http2.StreamFrame +import org.http4s.blaze.pipeline.TailStage + +import scala.concurrent._ + +private[http4s] class Http2Writer[F[_]]( + tail: TailStage[StreamFrame], + private var headers: Headers, +)(implicit protected val F: Async[F]) + extends EntityBodyWriter[F] { + override protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = { + val f = + if (headers == null) tail.channelWrite(DataFrame(endStream = true, chunk.toByteBuffer)) + else { + val hs = headers + headers = null + if (chunk.isEmpty) + tail.channelWrite(HeadersFrame(Priority.NoPriority, endStream = true, hs)) + else + tail.channelWrite( + HeadersFrame(Priority.NoPriority, endStream = false, hs) + :: DataFrame(endStream = true, chunk.toByteBuffer) + :: Nil + ) + } + + f.map(Function.const(false))(parasitic) + } + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + if (chunk.isEmpty) FutureUnit + else if (headers == null) tail.channelWrite(DataFrame(endStream = false, chunk.toByteBuffer)) + else { + val hs = headers + headers = null + tail.channelWrite( + HeadersFrame(Priority.NoPriority, endStream = false, hs) + :: DataFrame(endStream = false, chunk.toByteBuffer) + :: Nil + ) + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/IdentityWriter.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/IdentityWriter.scala new file mode 100644 index 000000000..0402b275d --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/IdentityWriter.scala @@ -0,0 +1,89 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import fs2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter +import org.log4s.getLogger + +import java.nio.ByteBuffer +import scala.concurrent.Future + +private[http4s] class IdentityWriter[F[_]](size: Long, out: TailStage[ByteBuffer])(implicit + protected val F: Async[F] +) extends Http1Writer[F] { + + private[this] val logger = getLogger + private[this] var headers: ByteBuffer = null + + private var bodyBytesWritten = 0L + + private def willOverflow(count: Int) = + if (size < 0L) false + else count.toLong + bodyBytesWritten > size + + def writeHeaders(headerWriter: StringWriter): Future[Unit] = { + headers = Http1Writer.headersToByteBuffer(headerWriter.result) + FutureUnit + } + + protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + if (willOverflow(chunk.size)) { + // never write past what we have promised using the Content-Length header + val msg = + s"Will not write more bytes than what was indicated by the Content-Length header ($size)" + + logger.warn(msg) + + val reducedChunk = chunk.take((size - bodyBytesWritten).toInt) + writeBodyChunk(reducedChunk, flush = true).flatMap(_ => + Future.failed(new IllegalArgumentException(msg)) + )(parasitic) + } else { + val b = chunk.toByteBuffer + + bodyBytesWritten += b.remaining + + if (headers != null) { + val h = headers + headers = null + out.channelWrite(h :: b :: Nil) + } else out.channelWrite(b) + } + + protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = { + val total = bodyBytesWritten + chunk.size + + if (size < 0 || total >= size) + writeBodyChunk(chunk, flush = true).map(Function.const(size < 0))( + parasitic + ) // require close if infinite + else { + val msg = s"Expected `Content-Length: $size` bytes, but only $total were written." + + logger.warn(msg) + + writeBodyChunk(chunk, flush = true).flatMap(_ => + Future.failed(new IllegalStateException(msg)) + )(parasitic) + } + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/util/package.scala b/blaze-core/src/main/scala/org/http4s/blazecore/util/package.scala new file mode 100644 index 000000000..694a3bd8f --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/util/package.scala @@ -0,0 +1,53 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import cats.effect.Async +import org.http4s.blaze.util.Execution.directec + +import scala.concurrent.Future +import scala.util.Failure +import scala.util.Success + +package object util extends ParasiticExecutionContextCompat { + + /** Used as a terminator for streams built from repeatEval */ + private[http4s] val End = Right(None) + + private[http4s] val FutureUnit = + Future.successful(()) + + // Adapted from https://github.com/typelevel/cats-effect/issues/199#issuecomment-401273282 + /** Inferior to `Async[F].fromFuture` for general use because it doesn't shift, but + * in blaze internals, we don't want to shift. + */ + private[http4s] def fromFutureNoShift[F[_], A](f: F[Future[A]])(implicit F: Async[F]): F[A] = + F.flatMap(f) { future => + future.value match { + case Some(value) => + F.fromTry(value) + case None => + F.async_ { cb => + future.onComplete { + case Success(a) => cb(Right(a)) + case Failure(t) => cb(Left(t)) + }(directec) + } + } + } +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala new file mode 100644 index 000000000..bca255526 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala @@ -0,0 +1,202 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package websocket + +import cats.effect._ +import cats.effect.std.Dispatcher +import cats.effect.std.Semaphore +import cats.syntax.all._ +import fs2._ +import fs2.concurrent.SignallingRef +import org.http4s.blaze.pipeline.Command.EOF +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.pipeline.TrunkBuilder +import org.http4s.blaze.util.Execution.directec +import org.http4s.blaze.util.Execution.trampoline +import org.http4s.websocket.ReservedOpcodeException +import org.http4s.websocket.UnknownOpcodeException +import org.http4s.websocket.WebSocket +import org.http4s.websocket.WebSocketCombinedPipe +import org.http4s.websocket.WebSocketFrame +import org.http4s.websocket.WebSocketFrame._ +import org.http4s.websocket.WebSocketSeparatePipe + +import java.net.ProtocolException +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.ExecutionContext +import scala.util.Failure +import scala.util.Success + +private[http4s] class Http4sWSStage[F[_]]( + ws: WebSocket[F], + sentClose: AtomicBoolean, + deadSignal: SignallingRef[F, Boolean], + writeSemaphore: Semaphore[F], + dispatcher: Dispatcher[F], +)(implicit F: Async[F]) + extends TailStage[WebSocketFrame] { + + def name: String = "Http4s WebSocket Stage" + + // ////////////////////// Source and Sink generators //////////////////////// + val isClosed: F[Boolean] = F.delay(sentClose.get()) + val setClosed: F[Boolean] = F.delay(sentClose.compareAndSet(false, true)) + + def evalFrame(frame: WebSocketFrame): F[Unit] = frame match { + case c: Close => setClosed.ifM(writeFrame(c, directec), F.unit) + case _ => writeFrame(frame, directec) + } + + def snkFun(frame: WebSocketFrame): F[Unit] = isClosed.ifM(F.unit, evalFrame(frame)) + + private[this] def writeFrame(frame: WebSocketFrame, ec: ExecutionContext): F[Unit] = + writeSemaphore.permit.use { _ => + F.async_[Unit] { cb => + channelWrite(frame).onComplete { + case Success(res) => cb(Right(res)) + case Failure(t) => cb(Left(t)) + }(ec) + } + } + + private[this] def readFrameTrampoline: F[WebSocketFrame] = + F.async_[WebSocketFrame] { cb => + channelRead().onComplete { + case Success(ws) => cb(Right(ws)) + case Failure(exception) => cb(Left(exception)) + }(trampoline) + } + + /** Read from our websocket. + * + * To stay faithful to the RFC, the following must hold: + * + * - If we receive a ping frame, we MUST reply with a pong frame + * - If we receive a pong frame, we don't need to forward it. + * - If we receive a close frame, it means either one of two things: + * - We sent a close frame prior, meaning we do not need to reply with one. Just end the stream + * - We are the first to receive a close frame, so we try to atomically check a boolean flag, + * to prevent sending two close frames. Regardless, we set the signal for termination of + * the stream afterwards + * + * @return A websocket frame, or a possible IO error. + */ + private[this] def handleRead(): F[WebSocketFrame] = { + def maybeSendClose(c: Close): F[Unit] = + F.delay(sentClose.compareAndSet(false, true)).flatMap { cond => + if (cond) writeFrame(c, trampoline) + else F.unit + } >> deadSignal.set(true) + + readFrameTrampoline + .recoverWith { + case t: ReservedOpcodeException => + F.delay(logger.error(t)("Decoded a websocket frame with a reserved opcode")) *> + F.fromEither(Close(1003)) + case t: UnknownOpcodeException => + F.delay(logger.error(t)("Decoded a websocket frame with an unknown opcode")) *> + F.fromEither(Close(1002)) + case t: ProtocolException => + F.delay(logger.error(t)("Websocket protocol violation")) *> F.fromEither(Close(1002)) + } + .flatMap { + case c: Close => + for { + s <- F.delay(sentClose.get()) + // If we sent a close signal, we don't need to reply with one + _ <- if (s) deadSignal.set(true) else maybeSendClose(c) + } yield c + case p @ Ping(d) => + // Reply to ping frame immediately + writeFrame(Pong(d), trampoline) >> F.pure(p) + case rest => + F.pure(rest) + } + } + + /** The websocket input stream + * + * Note: On receiving a close, we MUST send a close back, as stated in section + * 5.5.1 of the websocket spec: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + * + * @return + */ + def inputstream: Stream[F, WebSocketFrame] = + Stream.repeatEval(handleRead()) + + // ////////////////////// Startup and Shutdown //////////////////////// + + override protected def stageStartup(): Unit = { + super.stageStartup() + + // Effect to send a close to the other endpoint + val sendClose: F[Unit] = F.delay(closePipeline(None)) + + val receiveSent: Stream[F, WebSocketFrame] = + ws match { + case WebSocketSeparatePipe(send, receive, _) => + // We don't need to terminate if the send stream terminates. + send.concurrently(receive(inputstream)) + case WebSocketCombinedPipe(receiveSend, _) => + receiveSend(inputstream) + } + + val wsStream = + receiveSent + .evalMap(snkFun) + .drain + .interruptWhen(deadSignal) + .onFinalizeWeak( + ws.onClose.attempt.void + ) // Doing it this way ensures `sendClose` is sent no matter what + .onFinalizeWeak(sendClose) + .compile + .drain + + val result = F.handleErrorWith(wsStream) { + case EOF => + F.delay(stageShutdown()) + case t => + F.delay(logger.error(t)("Error closing Web Socket")) + } + dispatcher.unsafeRunAndForget(result) + } + + override protected def stageShutdown(): Unit = { + val fa = F.handleError(deadSignal.set(true)) { t => + logger.error(t)("Error setting dead signal") + } + dispatcher.unsafeRunAndForget(fa) + super.stageShutdown() + } +} + +object Http4sWSStage { + def bufferingSegment[F[_]](stage: Http4sWSStage[F]): LeafBuilder[WebSocketFrame] = + TrunkBuilder(new SerializingStage[WebSocketFrame]).cap(stage) + + def apply[F[_]]( + ws: WebSocket[F], + sentClose: AtomicBoolean, + deadSignal: SignallingRef[F, Boolean], + dispatcher: Dispatcher[F], + )(implicit F: Async[F]): F[Http4sWSStage[F]] = + Semaphore[F](1L).map(t => new Http4sWSStage(ws, sentClose, deadSignal, t, dispatcher)) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Serializer.scala b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Serializer.scala new file mode 100644 index 000000000..9867a5a1c --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Serializer.scala @@ -0,0 +1,141 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.websocket + +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.util.Execution._ + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.concurrent.duration.Duration +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +/** Combined [[WriteSerializer]] and [[ReadSerializer]] */ +private trait Serializer[I] extends WriteSerializer[I] with ReadSerializer[I] + +/** Serializes write requests, storing intermediates in a queue */ +private trait WriteSerializer[I] extends TailStage[I] { self => + + // ////////////////////////////////////////////////////////////////////// + + private var serializerWriteQueue = new ArrayBuffer[I] + private var serializerWritePromise: Promise[Unit] = null + + // / channel writing bits ////////////////////////////////////////////// + override def channelWrite(data: I): Future[Unit] = + channelWrite(data :: Nil) + + override def channelWrite(data: collection.Seq[I]): Future[Unit] = + synchronized { + if (serializerWritePromise == null) { // there is no queue! + serializerWritePromise = Promise[Unit]() + val f = super.channelWrite(data) + f.onComplete(checkQueue)(directec) + f + } else { + serializerWriteQueue ++= data + serializerWritePromise.future + } + } + + private def checkQueue(t: Try[Unit]): Unit = + t match { + case f @ Failure(_) => + val p = synchronized { + serializerWriteQueue.clear() + val p = serializerWritePromise + serializerWritePromise = null + p + } + p.tryComplete(f) + () + + case Success(_) => + synchronized { + if (serializerWriteQueue.isEmpty) + // Nobody has written anything + serializerWritePromise = null + else { + // stuff to write + val f = + if (serializerWriteQueue.length > 1) { // multiple messages, just give them the queue + val a = serializerWriteQueue + serializerWriteQueue = new ArrayBuffer[I](a.size + 10) + super.channelWrite(a) + } else { // only a single element to write, don't send the while queue + val h = serializerWriteQueue.head + serializerWriteQueue.clear() + super.channelWrite(h) + } + + val p = serializerWritePromise + serializerWritePromise = Promise[Unit]() + + f.onComplete { t => + checkQueue(t) + p.complete(t) + }(trampoline) + } + } + } +} + +/** Serializes read requests */ +trait ReadSerializer[I] extends TailStage[I] { + private val serializerReadRef = new AtomicReference[Future[I]](null) + + // / channel reading bits ////////////////////////////////////////////// + + override def channelRead(size: Int = -1, timeout: Duration = Duration.Inf): Future[I] = { + val p = Promise[I]() + val pending = serializerReadRef.getAndSet(p.future) + + if (pending == null) serializerDoRead(p, size, timeout) // no queue, just do a read + else { + val started = if (timeout.isFinite) System.currentTimeMillis() else 0 + pending.onComplete { _ => + val d = if (timeout.isFinite) { + val now = System.currentTimeMillis() + // make sure now is `now` is not before started since + // `currentTimeMillis` can return non-monotonic values. + if (now <= started) timeout + else timeout - Duration(now - started, TimeUnit.MILLISECONDS) + } else timeout + + serializerDoRead(p, size, d) + }(trampoline) + } // there is a queue, need to serialize behind it + + p.future + } + + private def serializerDoRead(p: Promise[I], size: Int, timeout: Duration): Unit = + super + .channelRead(size, timeout) + .onComplete { t => + serializerReadRef.compareAndSet( + p.future, + null, + ) // don't hold our reference if the queue is idle + p.complete(t) + }(directec) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/SerializingStage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/SerializingStage.scala new file mode 100644 index 000000000..74ec479d6 --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/SerializingStage.scala @@ -0,0 +1,33 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.websocket + +import org.http4s.blaze.pipeline.MidStage + +import scala.concurrent.Future + +private final class SerializingStage[I] extends PassThrough[I] with Serializer[I] { + val name: String = "SerializingStage" +} + +private abstract class PassThrough[I] extends MidStage[I, I] { + def readRequest(size: Int): Future[I] = channelRead(size) + + def writeRequest(data: I): Future[Unit] = channelWrite(data) + + override def writeRequest(data: collection.Seq[I]): Future[Unit] = channelWrite(data) +} diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/WebSocketHandshake.scala b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/WebSocketHandshake.scala new file mode 100644 index 000000000..19eeebd2d --- /dev/null +++ b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/WebSocketHandshake.scala @@ -0,0 +1,148 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.websocket + +import cats.MonadThrow +import cats.effect.std.Random +import cats.syntax.all._ +import org.http4s.crypto.Hash +import org.http4s.crypto.HashAlgorithm +import scodec.bits.ByteVector + +import java.nio.charset.StandardCharsets._ +import java.util.Base64 + +private[http4s] object WebSocketHandshake { + + /** Creates a new [[ClientHandshaker]] */ + def clientHandshaker[F[_]: MonadThrow](host: String, random: Random[F]): F[ClientHandshaker] = + for { + key <- random.nextBytes(16).map(Base64.getEncoder.encodeToString) + acceptKey <- genAcceptKey(key) + } yield new ClientHandshaker(host, key, acceptKey) + + /** Provides the initial headers and a 16 byte Base64 encoded random key for websocket connections */ + class ClientHandshaker(host: String, key: String, acceptKey: String) { + + /** Initial headers to send to the server */ + val initHeaders: List[(String, String)] = + ("Host", host) :: ("Sec-WebSocket-Key", key) :: clientBaseHeaders + + /** Check if the server response is a websocket handshake response */ + def checkResponse(headers: Iterable[(String, String)]): Either[String, Unit] = + if ( + !headers.exists { case (k, v) => + k.equalsIgnoreCase("Connection") && valueContains("Upgrade", v) + } + ) + Left("Bad Connection header") + else if ( + !headers.exists { case (k, v) => + k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket") + } + ) + Left("Bad Upgrade header") + else + headers + .find { case (k, _) => k.equalsIgnoreCase("Sec-WebSocket-Accept") } + .map { + case (_, v) if acceptKey === v => Either.unit + case (_, v) => Left(s"Invalid key: $v") + } + .getOrElse(Left("Missing Sec-WebSocket-Accept header")) + } + + /** Checks the headers received from the client and if they are valid, generates response headers */ + def serverHandshake( + headers: Iterable[(String, String)] + ): Either[(Int, String), collection.Seq[(String, String)]] = + if (!headers.exists { case (k, _) => k.equalsIgnoreCase("Host") }) + Left((-1, "Missing Host Header")) + else if ( + !headers.exists { case (k, v) => + k.equalsIgnoreCase("Connection") && valueContains("Upgrade", v) + } + ) + Left((-1, "Bad Connection header")) + else if ( + !headers.exists { case (k, v) => + k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket") + } + ) + Left((-1, "Bad Upgrade header")) + else if ( + !headers.exists { case (k, v) => + k.equalsIgnoreCase("Sec-WebSocket-Version") && valueContains("13", v) + } + ) + Left((-1, "Bad Websocket Version header")) + // we are past most of the 'just need them' headers + else + headers + .find { case (k, v) => + k.equalsIgnoreCase("Sec-WebSocket-Key") && decodeLen(v) == 16 + } + .map { case (_, key) => + genAcceptKey[Either[Throwable, *]](key) match { + case Left(_) => Left((-1, "Bad Sec-WebSocket-Key header")) + case Right(acceptKey) => + Right( + collection.Seq( + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", acceptKey), + ) + ) + } + } + .getOrElse(Left((-1, "Bad Sec-WebSocket-Key header"))) + + /** Check if the headers contain an 'Upgrade: websocket' header */ + def isWebSocketRequest(headers: Iterable[(String, String)]): Boolean = + headers.exists { case (k, v) => + k.equalsIgnoreCase("Upgrade") && v.equalsIgnoreCase("websocket") + } + + private def decodeLen(key: String): Int = Base64.getDecoder.decode(key).length + + private def genAcceptKey[F[_]](str: String)(implicit F: MonadThrow[F]): F[String] = for { + data <- F.fromEither(ByteVector.encodeAscii(str)) + digest <- Hash[F].digest(HashAlgorithm.SHA1, data ++ magicString) + } yield digest.toBase64 + + private[websocket] def valueContains(key: String, value: String): Boolean = { + val parts = value.split(",").map(_.trim) + parts.foldLeft(false)((b, s) => + b || { + s.equalsIgnoreCase(key) || + s.length > 1 && + s.startsWith("\"") && + s.endsWith("\"") && + s.substring(1, s.length - 1).equalsIgnoreCase(key) + } + ) + } + + private val magicString = + ByteVector.view("258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(US_ASCII)) + + private val clientBaseHeaders = List( + ("Connection", "Upgrade"), + ("Upgrade", "websocket"), + ("Sec-WebSocket-Version", "13"), + ) +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/DispatcherIOFixture.scala b/blaze-core/src/test/scala/org/http4s/blazecore/DispatcherIOFixture.scala new file mode 100644 index 000000000..7c5f6355a --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/DispatcherIOFixture.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import cats.effect.IO +import cats.effect.SyncIO +import cats.effect.std.Dispatcher +import munit.CatsEffectSuite + +trait DispatcherIOFixture { this: CatsEffectSuite => + + def dispatcher: SyncIO[FunFixture[Dispatcher[IO]]] = ResourceFixture(Dispatcher[IO]) + +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/ResponseParser.scala b/blaze-core/src/test/scala/org/http4s/blazecore/ResponseParser.scala new file mode 100644 index 000000000..88cd6ed7e --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/ResponseParser.scala @@ -0,0 +1,99 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import cats.syntax.all._ +import fs2._ +import org.http4s.blaze.http.parser.Http1ClientParser +import org.typelevel.ci.CIString + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.annotation.nowarn +import scala.collection.mutable.ListBuffer + +class ResponseParser extends Http1ClientParser { + private val headers = new ListBuffer[(String, String)] + + private var code: Int = -1 + @nowarn private var reason = "" + @nowarn private var majorversion = -1 + @nowarn private var minorversion = -1 + + /** Will not mutate the ByteBuffers in the Seq */ + def parseResponse(buffs: Seq[ByteBuffer]): (Status, Set[Header.Raw], String) = { + val b = + ByteBuffer.wrap(buffs.map(b => Chunk.byteBuffer(b).toArray).toArray.flatten) + parseResponseBuffer(b) + } + + /* Will mutate the ByteBuffer */ + def parseResponseBuffer(buffer: ByteBuffer): (Status, Set[Header.Raw], String) = { + parseResponseLine(buffer) + parseHeaders(buffer) + + if (!headersComplete()) sys.error("Headers didn't complete!") + val body = new ListBuffer[ByteBuffer] + while (!this.contentComplete() && buffer.hasRemaining) + body += parseContent(buffer) + + val bp = { + val bytes = + body.toList.foldLeft(Vector.empty[Chunk[Byte]])((vec, bb) => vec :+ Chunk.byteBuffer(bb)) + new String(Chunk.concat(bytes).toArray, StandardCharsets.ISO_8859_1) + } + + val headers: Set[Header.Raw] = this.headers + .result() + .toSet + .map { (kv: (String, String)) => + Header.Raw(CIString(kv._1), kv._2) + } + + val status = Status.fromInt(this.code).valueOr(throw _) + + (status, headers, bp) + } + + override protected def headerComplete(name: String, value: String): Boolean = { + headers += ((name, value)) + false + } + + override protected def submitResponseLine( + code: Int, + reason: String, + scheme: String, + majorversion: Int, + minorversion: Int, + ): Unit = { + this.code = code + this.reason = reason + this.majorversion = majorversion + this.minorversion = minorversion + } +} + +object ResponseParser { + def apply(buff: Seq[ByteBuffer]): (Status, Set[Header.Raw], String) = + new ResponseParser().parseResponse(buff) + def apply(buff: ByteBuffer): (Status, Set[Header.Raw], String) = parseBuffer(buff) + + def parseBuffer(buff: ByteBuffer): (Status, Set[Header.Raw], String) = + new ResponseParser().parseResponseBuffer(buff) +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/TestHead.scala b/blaze-core/src/test/scala/org/http4s/blazecore/TestHead.scala new file mode 100644 index 000000000..662d9eb47 --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/TestHead.scala @@ -0,0 +1,153 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore + +import cats.effect.IO +import cats.effect.std.Queue +import cats.effect.unsafe.implicits.global +import org.http4s.blaze.pipeline.Command._ +import org.http4s.blaze.pipeline.HeadStage +import org.http4s.blaze.util.TickWheelExecutor +import scodec.bits.ByteVector + +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference +import java.util.function.BinaryOperator +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.concurrent.duration.Duration +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +abstract class TestHead(val name: String) extends HeadStage[ByteBuffer] { + private val acc = new AtomicReference[ByteVector](ByteVector.empty) + private val p = Promise[ByteBuffer]() + private val binaryOperator: BinaryOperator[ByteVector] = (x: ByteVector, y: ByteVector) => x ++ y + + @volatile var closed = false + + @volatile var closeCauses: Vector[Option[Throwable]] = Vector[Option[Throwable]]() + + private[this] val disconnectSent = new AtomicBoolean(false) + + def getBytes(): Array[Byte] = acc.get().toArray + + val result = p.future + + override def writeRequest(data: ByteBuffer): Future[Unit] = + if (closed) Future.failed(EOF) + else { + acc.accumulateAndGet(ByteVector.view(data), binaryOperator) + util.FutureUnit + } + + override def stageShutdown(): Unit = { + closed = true + super.stageShutdown() + p.trySuccess(ByteBuffer.wrap(getBytes())) + () + } + + override def doClosePipeline(cause: Option[Throwable]): Unit = { + closeCauses :+= cause + cause.foreach(logger.error(_)(s"$name received unhandled error command")) + if (disconnectSent.compareAndSet(false, true)) + sendInboundCommand(Disconnected) + } +} + +class SeqTestHead(body: Seq[ByteBuffer]) extends TestHead("SeqTestHead") { + private val bodyIt = body.iterator + + override def readRequest(size: Int): Future[ByteBuffer] = + if (!closed && bodyIt.hasNext) Future.successful(bodyIt.next()) + else { + stageShutdown() + sendInboundCommand(Disconnected) + Future.failed(EOF) + } +} + +final class QueueTestHead(queue: Queue[IO, Option[ByteBuffer]]) extends TestHead("QueueTestHead") { + private val closedP = Promise[Nothing]() + + override def readRequest(size: Int): Future[ByteBuffer] = { + val p = Promise[ByteBuffer]() + p.completeWith( + queue.take + .flatMap { + case Some(bb) => IO.pure(bb) + case None => IO.raiseError(EOF) + } + .unsafeToFuture() + ) + p.completeWith(closedP.future) + p.future + } + + override def stageShutdown(): Unit = { + closedP.tryFailure(EOF) + super.stageShutdown() + } +} + +final class SlowTestHead(body: Seq[ByteBuffer], pause: Duration, scheduler: TickWheelExecutor) + extends TestHead("Slow TestHead") { self => + + private val bodyIt = body.iterator + private var currentRequest: Option[Promise[ByteBuffer]] = None + + private def resolvePending(result: Try[ByteBuffer]): Unit = { + currentRequest.foreach(_.tryComplete(result)) + currentRequest = None + } + + private def clear(): Unit = { + while (bodyIt.hasNext) bodyIt.next() + resolvePending(Failure(EOF)) + } + + override def stageShutdown(): Unit = { + clear() + super.stageShutdown() + } + + override def readRequest(size: Int): Future[ByteBuffer] = + currentRequest match { + case Some(_) => + Future.failed(new IllegalStateException("Cannot serve multiple concurrent read requests")) + case None => + val p = Promise[ByteBuffer]() + currentRequest = Some(p) + + scheduler.schedule( + new Runnable { + override def run(): Unit = + resolvePending { + if (!closed && bodyIt.hasNext) Success(bodyIt.next()) + else Failure(EOF) + } + }, + pause, + ) + + p.future + } +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/util/DumpingWriter.scala b/blaze-core/src/test/scala/org/http4s/blazecore/util/DumpingWriter.scala new file mode 100644 index 000000000..68f1af92b --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/util/DumpingWriter.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect.Async +import cats.effect.IO +import fs2._ + +import scala.collection.mutable.Buffer +import scala.concurrent.Future + +object DumpingWriter { + def dump(p: EntityBody[IO]): IO[Array[Byte]] = { + val w = new DumpingWriter() + for (_ <- w.writeEntityBody(p)) yield (w.toArray) + } +} + +class DumpingWriter(implicit protected val F: Async[IO]) extends EntityBodyWriter[IO] { + private val buffer = Buffer[Chunk[Byte]]() + + def toArray: Array[Byte] = + buffer.synchronized { + Chunk.concat(buffer).toArray + } + + override protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = + buffer.synchronized { + buffer += chunk + Future.successful(false) + } + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + buffer.synchronized { + buffer += chunk + FutureUnit + } +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/util/FailingWriter.scala b/blaze-core/src/test/scala/org/http4s/blazecore/util/FailingWriter.scala new file mode 100644 index 000000000..0bb7729e0 --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/util/FailingWriter.scala @@ -0,0 +1,33 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import fs2._ +import org.http4s.blaze.pipeline.Command.EOF + +import scala.concurrent.Future + +class FailingWriter(implicit protected val F: Async[IO]) extends EntityBodyWriter[IO] { + override protected def writeEnd(chunk: Chunk[Byte]): Future[Boolean] = + Future.failed(EOF) + + override protected def writeBodyChunk(chunk: Chunk[Byte], flush: Boolean): Future[Unit] = + Future.failed(EOF) +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/util/Http1WriterSpec.scala b/blaze-core/src/test/scala/org/http4s/blazecore/util/Http1WriterSpec.scala new file mode 100644 index 000000000..537e88f75 --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/util/Http1WriterSpec.scala @@ -0,0 +1,336 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blazecore +package util + +import cats.effect._ +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2.Stream._ +import fs2._ +import fs2.compression.Compression +import fs2.compression.DeflateParams +import munit.CatsEffectSuite +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.pipeline.TailStage +import org.http4s.util.StringWriter +import org.typelevel.ci._ + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.concurrent.Future + +class Http1WriterSpec extends CatsEffectSuite with DispatcherIOFixture { + case object Failed extends RuntimeException + + final def writeEntityBody( + p: EntityBody[IO] + )(builder: TailStage[ByteBuffer] => Http1Writer[IO]): IO[String] = { + val tail = new TailStage[ByteBuffer] { + override def name: String = "TestTail" + } + + val head = new TestHead("TestHead") { + override def readRequest(size: Int): Future[ByteBuffer] = + Future.failed(new Exception("Head doesn't read.")) + } + + LeafBuilder(tail).base(head) + val w = builder(tail) + + for { + _ <- IO.fromFuture(IO(w.writeHeaders(new StringWriter << "Content-Type: text/plain\r\n"))) + _ <- w.writeEntityBody(p).attempt + _ <- IO(head.stageShutdown()) + _ <- IO.fromFuture(IO(head.result)) + } yield new String(head.getBytes(), StandardCharsets.ISO_8859_1) + } + + private val message = "Hello world!" + private val messageBuffer = Chunk.array(message.getBytes(StandardCharsets.ISO_8859_1)) + + final def runNonChunkedTests( + name: String, + builder: Dispatcher[IO] => TailStage[ByteBuffer] => Http1Writer[IO], + ): Unit = { + dispatcher.test(s"$name Write a single emit") { implicit dispatcher => + writeEntityBody(chunk(messageBuffer))(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 12\r\n\r\n" + message) + } + + dispatcher.test(s"$name Write two emits") { implicit dispatcher => + val p = chunk(messageBuffer) ++ chunk(messageBuffer) + writeEntityBody(p)(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 24\r\n\r\n" + message + message) + } + + dispatcher.test(s"$name Write an await") { implicit dispatcher => + val p = eval(IO(messageBuffer)).flatMap(chunk(_)) + writeEntityBody(p)(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 12\r\n\r\n" + message) + } + + dispatcher.test(s"$name Write two awaits") { implicit dispatcher => + val p = eval(IO(messageBuffer)).flatMap(chunk(_)) + writeEntityBody(p ++ p)(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 24\r\n\r\n" + message + message) + } + + dispatcher.test(s"$name Write a body that fails and falls back") { implicit dispatcher => + val p = eval(IO.raiseError(Failed)).handleErrorWith { _ => + chunk(messageBuffer) + } + writeEntityBody(p)(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 12\r\n\r\n" + message) + } + + dispatcher.test(s"$name execute cleanup") { implicit dispatcher => + (for { + clean <- Ref.of[IO, Boolean](false) + p = chunk(messageBuffer).onFinalizeWeak(clean.set(true)) + r <- writeEntityBody(p)(builder(dispatcher)) + .map(_ == "Content-Type: text/plain\r\nContent-Length: 12\r\n\r\n" + message) + c <- clean.get + } yield r && c).assert + } + + dispatcher.test(s"$name Write tasks that repeat eval") { implicit dispatcher => + val t = { + var counter = 2 + IO { + counter -= 1 + if (counter >= 0) Some(Chunk.array("foo".getBytes(StandardCharsets.ISO_8859_1))) + else None + } + } + val p = repeatEval(t).unNoneTerminate.flatMap(chunk(_)) ++ chunk( + Chunk.array("bar".getBytes(StandardCharsets.ISO_8859_1)) + ) + writeEntityBody(p)(builder(dispatcher)) + .assertEquals("Content-Type: text/plain\r\nContent-Length: 9\r\n\r\n" + "foofoobar") + } + } + + runNonChunkedTests( + "CachingChunkWriter", + implicit dispatcher => + tail => new CachingChunkWriter[IO](tail, IO.pure(Headers.empty), 1024 * 1024, false), + ) + + runNonChunkedTests( + "CachingStaticWriter", + implicit dispatcher => + tail => new CachingChunkWriter[IO](tail, IO.pure(Headers.empty), 1024 * 1024, false), + ) + + def builder(tail: TailStage[ByteBuffer])(implicit D: Dispatcher[IO]): FlushingChunkWriter[IO] = + new FlushingChunkWriter[IO](tail, IO.pure(Headers.empty)) + + dispatcher.test("FlushingChunkWriter should Write a strict chunk") { implicit d => + // n.b. in the scalaz-stream version, we could introspect the + // stream, note the lack of effects, and write this with a + // Content-Length header. In fs2, this must be chunked. + writeEntityBody(chunk(messageBuffer))(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should Write two strict chunks") { implicit d => + val p = chunk(messageBuffer) ++ chunk(messageBuffer) + writeEntityBody(p)(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should Write an effectful chunk") { implicit d => + // n.b. in the scalaz-stream version, we could introspect the + // stream, note the chunk was followed by halt, and write this + // with a Content-Length header. In fs2, this must be chunked. + val p = eval(IO(messageBuffer)).flatMap(chunk(_)) + writeEntityBody(p)(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should Write two effectful chunks") { implicit d => + val p = eval(IO(messageBuffer)).flatMap(chunk(_)) + writeEntityBody(p ++ p)(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should Elide empty chunks") { implicit d => + // n.b. We don't do anything special here. This is a feature of + // fs2, but it's important enough we should check it here. + val p: Stream[IO, Byte] = chunk(Chunk.empty) ++ chunk(messageBuffer) + writeEntityBody(p)(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should Write a body that fails and falls back") { + implicit d => + val p = eval(IO.raiseError(Failed)).handleErrorWith { _ => + chunk(messageBuffer) + } + writeEntityBody(p)(builder).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n")) + } + + dispatcher.test("FlushingChunkWriter should execute cleanup") { implicit d => + (for { + clean <- Ref.of[IO, Boolean](false) + p = chunk(messageBuffer).onFinalizeWeak(clean.set(true)) + w <- writeEntityBody(p)(builder).map( + _ == + """Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + | + |""".stripMargin.replace("\n", "\r\n") + ) + c <- clean.get + _ <- clean.set(false) + p2 = eval(IO.raiseError(new RuntimeException("asdf"))).onFinalizeWeak(clean.set(true)) + _ <- writeEntityBody(p2)(builder) + c2 <- clean.get + } yield w && c && c2).assert + } + + // Some tests for the raw unwinding body without HTTP encoding. + test("FlushingChunkWriter should write a deflated stream") { + val s = eval(IO(messageBuffer)).flatMap(chunk(_)) + val p = s.through(Compression[IO].deflate(DeflateParams.DEFAULT)) + ( + p.compile.toVector.map(_.toArray), + DumpingWriter.dump(s.through(Compression[IO].deflate(DeflateParams.DEFAULT))), + ) + .mapN(_ sameElements _) + .assert + } + + val resource: Stream[IO, Byte] = + bracket(IO("foo"))(_ => IO.unit).flatMap { str => + val it = str.iterator + emit { + if (it.hasNext) Some(it.next().toByte) + else None + } + }.unNoneTerminate + + test("FlushingChunkWriter should write a resource") { + val p = resource + (p.compile.toVector.map(_.toArray), DumpingWriter.dump(p)).mapN(_ sameElements _).assert + } + + test("FlushingChunkWriter should write a deflated resource") { + val p = resource.through(Compression[IO].deflate(DeflateParams.DEFAULT)) + ( + p.compile.toVector.map(_.toArray), + DumpingWriter.dump(resource.through(Compression[IO].deflate(DeflateParams.DEFAULT))), + ) + .mapN(_ sameElements _) + .assert + } + + test("FlushingChunkWriter should must be stack safe") { + val p = repeatEval(IO.pure[Byte](0.toByte)).take(300000) + + // The dumping writer is stack safe when using a trampolining EC + (new DumpingWriter).writeEntityBody(p).attempt.map(_.isRight).assert + } + + test("FlushingChunkWriter should Execute cleanup on a failing Http1Writer") { + (for { + clean <- Ref.of[IO, Boolean](false) + p = chunk(messageBuffer).onFinalizeWeak(clean.set(true)) + w <- new FailingWriter().writeEntityBody(p).attempt + c <- clean.get + } yield w.isLeft && c).assert + } + + test( + "FlushingChunkWriter should Execute cleanup on a failing Http1Writer with a failing process" + ) { + (for { + clean <- Ref.of[IO, Boolean](false) + p = eval(IO.raiseError(Failed)).onFinalizeWeak(clean.set(true)) + w <- new FailingWriter().writeEntityBody(p).attempt + c <- clean.get + } yield w.isLeft && c).assert + } + + dispatcher.test("FlushingChunkWriter should Write trailer headers") { implicit d => + def builderWithTrailer(tail: TailStage[ByteBuffer]): FlushingChunkWriter[IO] = + new FlushingChunkWriter[IO]( + tail, + IO.pure(Headers(Header.Raw(ci"X-Trailer", "trailer header value"))), + ) + + val p = eval(IO(messageBuffer)).flatMap(chunk(_)) + + writeEntityBody(p)(builderWithTrailer).assertEquals("""Content-Type: text/plain + |Transfer-Encoding: chunked + | + |c + |Hello world! + |0 + |X-Trailer: trailer header value + | + |""".stripMargin.replace("\n", "\r\n")) + } + +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala new file mode 100644 index 000000000..833b0e972 --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala @@ -0,0 +1,185 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore +package websocket + +import cats.effect.IO +import cats.effect.std.Dispatcher +import cats.effect.std.Queue +import cats.syntax.all._ +import fs2.Stream +import fs2.concurrent.SignallingRef +import munit.CatsEffectSuite +import org.http4s.blaze.pipeline.Command +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.websocket.WebSocketFrame +import org.http4s.websocket.WebSocketFrame._ +import org.http4s.websocket.WebSocketSeparatePipe +import scodec.bits.ByteVector + +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ + +class Http4sWSStageSpec extends CatsEffectSuite with DispatcherIOFixture { + implicit val testExecutionContext: ExecutionContext = munitExecutionContext + + class TestWebsocketStage( + outQ: Queue[IO, WebSocketFrame], + head: WSTestHead, + closeHook: AtomicBoolean, + backendInQ: Queue[IO, WebSocketFrame], + ) { + def sendWSOutbound(w: WebSocketFrame*): IO[Unit] = + Stream + .emits(w) + .evalMap(outQ.offer) + .compile + .drain + + def sendInbound(w: WebSocketFrame*): IO[Unit] = + w.toList.traverse(head.put).void + + def pollOutbound(timeoutSeconds: Long = 4L): IO[Option[WebSocketFrame]] = + head.poll(timeoutSeconds) + + def pollBackendInbound(timeoutSeconds: Long = 4L): IO[Option[WebSocketFrame]] = + IO.race(backendInQ.take, IO.sleep(timeoutSeconds.seconds)) + .map(_.fold(Some(_), _ => None)) + + def pollBatchOutputbound(batchSize: Int, timeoutSeconds: Long = 4L): IO[List[WebSocketFrame]] = + head.pollBatch(batchSize, timeoutSeconds) + + val outStream: Stream[IO, WebSocketFrame] = + head.outStream + + def wasCloseHookCalled(): IO[Boolean] = + IO(closeHook.get()) + } + + object TestWebsocketStage { + def apply()(implicit dispatcher: Dispatcher[IO]): IO[TestWebsocketStage] = + for { + outQ <- Queue.unbounded[IO, WebSocketFrame] + backendInQ <- Queue.unbounded[IO, WebSocketFrame] + closeHook = new AtomicBoolean(false) + ws = WebSocketSeparatePipe[IO]( + Stream.repeatEval(outQ.take), + _.evalMap(backendInQ.offer), + IO(closeHook.set(true)), + ) + deadSignal <- SignallingRef[IO, Boolean](false) + wsHead <- WSTestHead() + http4sWSStage <- Http4sWSStage[IO](ws, closeHook, deadSignal, dispatcher) + head = LeafBuilder(http4sWSStage).base(wsHead) + _ <- IO(head.sendInboundCommand(Command.Connected)) + } yield new TestWebsocketStage(outQ, head, closeHook, backendInQ) + } + + dispatcher.test("Http4sWSStage should reply with pong immediately after ping".flaky) { + implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendInbound(Ping()) + p <- socket.pollOutbound(2).map(_.exists(_ == Pong())) + _ <- socket.sendInbound(Close()) + } yield assert(p) + } + + dispatcher.test("Http4sWSStage should not write any more frames after close frame sent") { + implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendWSOutbound(Text("hi"), Close(), Text("lol")) + p1 <- socket.pollOutbound().map(_.contains(Text("hi"))) + p2 <- socket.pollOutbound().map(_.contains(Close())) + p3 <- socket.pollOutbound().map(_.isEmpty) + _ <- socket.sendInbound(Close()) + } yield assert(p1 && p2 && p3) + } + + dispatcher.test( + "Http4sWSStage should send a close frame back and call the on close handler upon receiving a close frame" + ) { implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendInbound(Close()) + p1 <- socket.pollBatchOutputbound(2, 2).map(_ == List(Close())) + p2 <- socket.wasCloseHookCalled().map(_ == true) + } yield assert(p1 && p2) + } + + dispatcher.test("Http4sWSStage should not send two close frames".flaky) { implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendWSOutbound(Close()) + _ <- socket.sendInbound(Close()) + p1 <- socket.pollBatchOutputbound(2).map(_ == List(Close())) + p2 <- socket.wasCloseHookCalled() + } yield assert(p1 && p2) + } + + dispatcher.test("Http4sWSStage should ignore pong frames") { implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendInbound(Pong()) + p <- socket.pollOutbound().map(_.isEmpty) + _ <- socket.sendInbound(Close()) + } yield assert(p) + } + + dispatcher.test("Http4sWSStage should send a ping frames to backend") { implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendInbound(Ping()) + p1 <- socket.pollBackendInbound().map(_.contains(Ping())) + pingWithBytes = Ping(ByteVector(Array[Byte](1, 2, 3))) + _ <- socket.sendInbound(pingWithBytes) + p2 <- socket.pollBackendInbound().map(_.contains(pingWithBytes)) + _ <- socket.sendInbound(Close()) + } yield assert(p1 && p2) + } + + dispatcher.test("Http4sWSStage should send a pong frames to backend") { implicit d => + for { + socket <- TestWebsocketStage() + _ <- socket.sendInbound(Pong()) + p1 <- socket.pollBackendInbound().map(_.contains(Pong())) + pongWithBytes = Pong(ByteVector(Array[Byte](1, 2, 3))) + _ <- socket.sendInbound(pongWithBytes) + p2 <- socket.pollBackendInbound().map(_.contains(pongWithBytes)) + _ <- socket.sendInbound(Close()) + } yield assert(p1 && p2) + } + + dispatcher.test("Http4sWSStage should not fail on pending write request") { implicit d => + for { + socket <- TestWebsocketStage() + reasonSent = ByteVector(42) + in = Stream.eval(socket.sendInbound(Ping())).repeat.take(100) + out = Stream.eval(socket.sendWSOutbound(Text("."))).repeat.take(200) + _ <- in.merge(out).compile.drain + _ <- socket.sendInbound(Close(reasonSent)) + reasonReceived <- + socket.outStream + .collectFirst { case Close(reasonReceived) => reasonReceived } + .compile + .toList + .timeout(5.seconds) + } yield assertEquals(reasonReceived, List(reasonSent)) + } +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WSTestHead.scala b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WSTestHead.scala new file mode 100644 index 000000000..60bd9ada2 --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WSTestHead.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.websocket + +import cats.effect.IO +import cats.effect.std.Queue +import cats.effect.std.Semaphore +import cats.effect.unsafe.implicits.global +import cats.syntax.all._ +import fs2.Stream +import org.http4s.blaze.pipeline.HeadStage +import org.http4s.websocket.WebSocketFrame + +import scala.concurrent.Future +import scala.concurrent.duration._ + +/** A simple stage to help test websocket requests + * + * This is really disgusting code but bear with me here: + * `java.util.LinkedBlockingDeque` does NOT have nodes with + * atomic references. We need to check finalizers, and those are run concurrently + * and nondeterministically, so we're in a sort of hairy situation + * for checking finalizers doing anything other than waiting on an update + * + * That is, on updates, we may easily run into a lost update problem if + * nodes are checked by a different thread since node values have no + * atomicity guarantee by the jvm. I simply want to provide a (blocking) + * way of reading a websocket frame, to emulate reading from a socket. + */ +sealed abstract class WSTestHead( + inQueue: Queue[IO, WebSocketFrame], + outQueue: Queue[IO, WebSocketFrame], + writeSemaphore: Semaphore[IO], +) extends HeadStage[WebSocketFrame] { + + /** Block while we put elements into our queue + * + * @return + */ + override def readRequest(size: Int): Future[WebSocketFrame] = + inQueue.take.unsafeToFuture() + + /** Sent downstream from the websocket stage, + * put the result in our outqueue, so we may + * pull from it later to inspect it + */ + override def writeRequest(data: WebSocketFrame): Future[Unit] = + writeSemaphore.tryAcquire + .flatMap { + case true => + outQueue.offer(data) *> writeSemaphore.release + case false => + IO.raiseError(new IllegalStateException("pending write")) + } + .unsafeToFuture() + + /** Insert data into the read queue, + * so it's read by the websocket stage + * @param ws + */ + def put(ws: WebSocketFrame): IO[Unit] = + inQueue.offer(ws) + + val outStream: Stream[IO, WebSocketFrame] = + Stream.repeatEval(outQueue.take) + + /** poll our queue for a value, + * timing out after `timeoutSeconds` seconds + * runWorker(this); + */ + def poll(timeoutSeconds: Long): IO[Option[WebSocketFrame]] = + IO.race(IO.sleep(timeoutSeconds.seconds), outQueue.take) + .map { + case Left(_) => None + case Right(wsFrame) => + Some(wsFrame) + } + + def pollBatch(batchSize: Int, timeoutSeconds: Long): IO[List[WebSocketFrame]] = { + def batch(acc: List[WebSocketFrame]): IO[List[WebSocketFrame]] = + if (acc.length == 0) { + outQueue.take.flatMap { frame => + batch(List(frame)) + } + } else if (acc.length < batchSize) { + outQueue.tryTake.flatMap { + case Some(frame) => batch(acc :+ frame) + case None => IO.pure(acc) + } + } else { + IO.pure(acc) + } + + batch(Nil) + .timeoutTo(timeoutSeconds.seconds, IO.pure(Nil)) + } + + override def name: String = "WS test stage" + + override protected def doClosePipeline(cause: Option[Throwable]): Unit = {} +} + +object WSTestHead { + def apply(): IO[WSTestHead] = + (Queue.unbounded[IO, WebSocketFrame], Queue.unbounded[IO, WebSocketFrame], Semaphore[IO](1L)) + .mapN(new WSTestHead(_, _, _) {}) +} diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WebSocketHandshakeSuite.scala b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WebSocketHandshakeSuite.scala new file mode 100644 index 000000000..e8bb554cb --- /dev/null +++ b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/WebSocketHandshakeSuite.scala @@ -0,0 +1,43 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blazecore.websocket + +import cats.effect.IO +import cats.effect.std.Random +import munit.CatsEffectSuite + +class WebSocketHandshakeSuite extends CatsEffectSuite { + + test("WebSocketHandshake should Be able to split multi value header keys") { + val totalValue = "keep-alive, Upgrade" + val values = List("upgrade", "Upgrade", "keep-alive", "Keep-alive") + assert(values.forall(v => WebSocketHandshake.valueContains(v, totalValue))) + } + + test("WebSocketHandshake should do a round trip") { + for { + random <- Random.javaSecuritySecureRandom[IO] + client <- WebSocketHandshake.clientHandshaker[IO]("www.foo.com", random) + hs = client.initHeaders + valid = WebSocketHandshake.serverHandshake(hs) + _ = assert(valid.isRight) + response = client.checkResponse(valid.toOption.get) + _ = assert(response.isRight, response.swap.toOption.get) + } yield () + } + +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala new file mode 100644 index 000000000..9fa8417ae --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala @@ -0,0 +1,590 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.Applicative +import cats.data.Kleisli +import cats.effect.Async +import cats.effect.Resource +import cats.effect.Sync +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import com.comcast.ip4s.IpAddress +import com.comcast.ip4s.Port +import com.comcast.ip4s.SocketAddress +import org.http4s.blaze.channel._ +import org.http4s.blaze.channel.nio1.NIO1SocketServerGroup +import org.http4s.blaze.http.http2.server.ALPNServerSelector +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.pipeline.stages.SSLStage +import org.http4s.blaze.server.BlazeServerBuilder._ +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blaze.{BuildInfo => BlazeBuildInfo} +import org.http4s.blazecore.BlazeBackendBuilder +import org.http4s.blazecore.ExecutionContextConfig +import org.http4s.blazecore.tickWheelResource +import org.http4s.internal.threads.threadFactory +import org.http4s.internal.tls.deduceKeyLength +import org.http4s.internal.tls.getCertChain +import org.http4s.server.SSLKeyStoreSupport.StoreInfo +import org.http4s.server._ +import org.http4s.server.websocket.WebSocketBuilder2 +import org.http4s.websocket.WebSocketContext +import org.http4s.{BuildInfo => Http4sBuildInfo} +import org.log4s.getLogger +import org.typelevel.vault._ +import scodec.bits.ByteVector + +import java.io.FileInputStream +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.security.KeyStore +import java.security.Security +import java.util.concurrent.ThreadFactory +import javax.net.ssl._ +import scala.collection.immutable +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration._ + +/** BlazeServerBuilder is the component for the builder pattern aggregating + * different components to finally serve requests. + * + * Variables: + * @param socketAddress: Socket Address the server will be mounted at + * @param responseHeaderTimeout: Time from when the request is made until a + * response line is generated before a 503 response is returned and the + * `HttpApp` is canceled + * @param idleTimeout: Period of Time a connection can remain idle before the + * connection is timed out and disconnected. + * Duration.Inf disables this feature. + * @param connectorPoolSize: Number of worker threads for the new Socket Server Group + * @param bufferSize: Buffer size to use for IO operations + * @param isHttp2Enabled: Whether or not to enable Http2 Server Features + * @param maxRequestLineLen: Maximum request line to parse + * If exceeded returns a 400 Bad Request. + * @param maxHeadersLen: Maximum data that composes the headers. + * If exceeded returns a 400 Bad Request. + * @param chunkBufferMaxSize Size of the buffer that is used when Content-Length header is not specified. + * @param httpApp: The services that are mounted on this server to serve.. + * @param serviceErrorHandler: The last resort to recover and generate a response + * this is necessary to recover totality from the error condition. + * @param banner: Pretty log to display on server start. An empty sequence + * such as Nil disables this + * @param maxConnections: The maximum number of client connections that may be active at any time. + * @param maxWebSocketBufferSize: The maximum Websocket buffer length. 'None' means unbounded. + */ +class BlazeServerBuilder[F[_]] private ( + socketAddress: InetSocketAddress, + executionContextConfig: ExecutionContextConfig, + responseHeaderTimeout: Duration, + idleTimeout: Duration, + connectorPoolSize: Int, + bufferSize: Int, + selectorThreadFactory: ThreadFactory, + sslConfig: SslConfig[F], + isHttp2Enabled: Boolean, + maxRequestLineLen: Int, + maxHeadersLen: Int, + chunkBufferMaxSize: Int, + httpApp: WebSocketBuilder2[F] => HttpApp[F], + serviceErrorHandler: ServiceErrorHandler[F], + banner: immutable.Seq[String], + maxConnections: Int, + val channelOptions: ChannelOptions, + maxWebSocketBufferSize: Option[Int], +)(implicit protected val F: Async[F]) + extends ServerBuilder[F] + with BlazeBackendBuilder[Server] { + type Self = BlazeServerBuilder[F] + + private[this] val logger = getLogger + + private def copy( + socketAddress: InetSocketAddress = socketAddress, + executionContextConfig: ExecutionContextConfig = executionContextConfig, + idleTimeout: Duration = idleTimeout, + responseHeaderTimeout: Duration = responseHeaderTimeout, + connectorPoolSize: Int = connectorPoolSize, + bufferSize: Int = bufferSize, + selectorThreadFactory: ThreadFactory = selectorThreadFactory, + sslConfig: SslConfig[F] = sslConfig, + http2Support: Boolean = isHttp2Enabled, + maxRequestLineLen: Int = maxRequestLineLen, + maxHeadersLen: Int = maxHeadersLen, + chunkBufferMaxSize: Int = chunkBufferMaxSize, + httpApp: WebSocketBuilder2[F] => HttpApp[F] = httpApp, + serviceErrorHandler: ServiceErrorHandler[F] = serviceErrorHandler, + banner: immutable.Seq[String] = banner, + maxConnections: Int = maxConnections, + channelOptions: ChannelOptions = channelOptions, + maxWebSocketBufferSize: Option[Int] = maxWebSocketBufferSize, + ): Self = + new BlazeServerBuilder( + socketAddress, + executionContextConfig, + responseHeaderTimeout, + idleTimeout, + connectorPoolSize, + bufferSize, + selectorThreadFactory, + sslConfig, + http2Support, + maxRequestLineLen, + maxHeadersLen, + chunkBufferMaxSize, + httpApp, + serviceErrorHandler, + banner, + maxConnections, + channelOptions, + maxWebSocketBufferSize, + ) + + /** Configure HTTP parser length limits + * + * These are to avoid denial of service attacks due to, + * for example, an infinite request line. + * + * @param maxRequestLineLen maximum request line to parse + * @param maxHeadersLen maximum data that compose headers + */ + def withLengthLimits( + maxRequestLineLen: Int = maxRequestLineLen, + maxHeadersLen: Int = maxHeadersLen, + ): Self = + copy(maxRequestLineLen = maxRequestLineLen, maxHeadersLen = maxHeadersLen) + + @deprecated( + "Build an `SSLContext` from the first four parameters and use `withSslContext` (note lowercase). To also request client certificates, use `withSslContextAndParameters, calling either `.setWantClientAuth(true)` or `setNeedClientAuth(true)` on the `SSLParameters`.", + "0.21.0-RC3", + ) + def withSSL( + keyStore: StoreInfo, + keyManagerPassword: String, + protocol: String = "TLS", + trustStore: Option[StoreInfo] = None, + clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested, + ): Self = { + val bits = new KeyStoreBits[F](keyStore, keyManagerPassword, protocol, trustStore, clientAuth) + copy(sslConfig = bits) + } + + @deprecated( + "Use `withSslContext` (note lowercase). To request client certificates, use `withSslContextAndParameters, calling either `.setWantClientAuth(true)` or `setNeedClientAuth(true)` on the `SSLParameters`.", + "0.21.0-RC3", + ) + def withSSLContext( + sslContext: SSLContext, + clientAuth: SSLClientAuthMode = SSLClientAuthMode.NotRequested, + ): Self = + copy(sslConfig = new ContextWithClientAuth[F](sslContext, clientAuth)) + + /** Configures the server with TLS, using the provided `SSLContext` and its + * default `SSLParameters` + */ + def withSslContext(sslContext: SSLContext): Self = + copy(sslConfig = new ContextOnly[F](sslContext)) + + /** Configures the server with TLS, using the provided `SSLContext` and `SSLParameters`. */ + def withSslContextAndParameters(sslContext: SSLContext, sslParameters: SSLParameters): Self = + copy(sslConfig = new ContextWithParameters[F](sslContext, sslParameters)) + + def withoutSsl: Self = + copy(sslConfig = new NoSsl[F]()) + + override def bindSocketAddress(socketAddress: InetSocketAddress): Self = + copy(socketAddress = socketAddress) + + /** Configures the compute thread pool used to process some async computations. + * + * This defaults to `cats.effect.Async[F].executionContext`. In + * almost all cases, it is desirable to use the default. + * + * The Blaze server has a single-threaded event loop receiver used + * for picking up tcp connections which is completely separate to + * this pool. Following picking up a tcp connection, Blaze shifts + * to a compute pool to process requests. The request processing + * logic specified by the `HttpApp` is executed on the + * `cats.effect.Async[F].executionContext`. Some of the other async + * computations involved in request processing are executed on this + * pool. + */ + def withExecutionContext(executionContext: ExecutionContext): BlazeServerBuilder[F] = + copy(executionContextConfig = ExecutionContextConfig.ExplicitContext(executionContext)) + + def withIdleTimeout(idleTimeout: Duration): Self = copy(idleTimeout = idleTimeout) + + def withResponseHeaderTimeout(responseHeaderTimeout: Duration): Self = + copy(responseHeaderTimeout = responseHeaderTimeout) + + def withConnectorPoolSize(size: Int): Self = copy(connectorPoolSize = size) + + def withBufferSize(size: Int): Self = copy(bufferSize = size) + + def withSelectorThreadFactory(selectorThreadFactory: ThreadFactory): Self = + copy(selectorThreadFactory = selectorThreadFactory) + + @deprecated("This operation is a no-op. WebSockets are always enabled.", "0.23") + def withWebSockets(enableWebsockets: Boolean): Self = + this + + def enableHttp2(enabled: Boolean): Self = copy(http2Support = enabled) + + def withHttpApp(httpApp: HttpApp[F]): Self = + copy(httpApp = _ => httpApp) + + def withHttpWebSocketApp(f: WebSocketBuilder2[F] => HttpApp[F]): Self = + copy(httpApp = f) + + def withServiceErrorHandler(serviceErrorHandler: ServiceErrorHandler[F]): Self = + copy(serviceErrorHandler = serviceErrorHandler) + + def withBanner(banner: immutable.Seq[String]): Self = + copy(banner = banner) + + def withChannelOptions(channelOptions: ChannelOptions): BlazeServerBuilder[F] = + copy(channelOptions = channelOptions) + + def withMaxRequestLineLength(maxRequestLineLength: Int): BlazeServerBuilder[F] = + copy(maxRequestLineLen = maxRequestLineLength) + + def withMaxHeadersLength(maxHeadersLength: Int): BlazeServerBuilder[F] = + copy(maxHeadersLen = maxHeadersLength) + + def withChunkBufferMaxSize(chunkBufferMaxSize: Int): BlazeServerBuilder[F] = + copy(chunkBufferMaxSize = chunkBufferMaxSize) + + def withMaxConnections(maxConnections: Int): BlazeServerBuilder[F] = + copy(maxConnections = maxConnections) + + def withMaxWebSocketBufferSize(maxWebSocketBufferSize: Option[Int]): BlazeServerBuilder[F] = + copy(maxWebSocketBufferSize = maxWebSocketBufferSize) + + private def pipelineFactory( + scheduler: TickWheelExecutor, + engineConfig: Option[(SSLContext, SSLEngine => Unit)], + dispatcher: Dispatcher[F], + )(conn: SocketConnection): Future[LeafBuilder[ByteBuffer]] = { + def requestAttributes(secure: Boolean, optionalSslEngine: Option[SSLEngine]): () => Vault = + (conn.local, conn.remote) match { + case (local: InetSocketAddress, remote: InetSocketAddress) => + () => { + val connection = Request.Connection( + local = SocketAddress( + IpAddress.fromBytes(local.getAddress.getAddress).get, + Port.fromInt(local.getPort).get, + ), + remote = SocketAddress( + IpAddress.fromBytes(remote.getAddress.getAddress).get, + Port.fromInt(remote.getPort).get, + ), + secure = secure, + ) + + // Create SSLSession object only for https requests and if current SSL session is not empty. + // Here, each condition is checked inside a "flatMap" to handle possible "null" values + def secureSession: Option[SecureSession] = + for { + engine <- optionalSslEngine + session <- Option(engine.getSession) + hex <- Option(session.getId).map(ByteVector(_).toHex) + cipher <- Option(session.getCipherSuite) + } yield SecureSession(hex, cipher, deduceKeyLength(cipher), getCertChain(session)) + + Vault.empty + .insert(Request.Keys.ConnectionInfo, connection) + .insert(ServerRequestKeys.SecureSession, if (secure) secureSession else None) + } + + case _ => + () => Vault.empty + } + + def http1Stage( + executionContext: ExecutionContext, + secure: Boolean, + engine: Option[SSLEngine], + webSocketKey: Key[WebSocketContext[F]], + ) = + Http1ServerStage( + httpApp(WebSocketBuilder2(webSocketKey)), + requestAttributes(secure = secure, engine), + executionContext, + webSocketKey, + maxRequestLineLen, + maxHeadersLen, + chunkBufferMaxSize, + serviceErrorHandler, + responseHeaderTimeout, + idleTimeout, + scheduler, + dispatcher, + maxWebSocketBufferSize, + ) + + def http2Stage( + executionContext: ExecutionContext, + engine: SSLEngine, + webSocketKey: Key[WebSocketContext[F]], + ): ALPNServerSelector = + ProtocolSelector( + engine, + httpApp(WebSocketBuilder2(webSocketKey)), + maxRequestLineLen, + maxHeadersLen, + chunkBufferMaxSize, + requestAttributes(secure = true, engine.some), + executionContext, + serviceErrorHandler, + responseHeaderTimeout, + idleTimeout, + scheduler, + dispatcher, + webSocketKey, + maxWebSocketBufferSize, + ) + + dispatcher.unsafeToFuture { + Key.newKey[F, WebSocketContext[F]].flatMap { wsKey => + executionContextConfig.getExecutionContext[F].map { executionContext => + engineConfig match { + case Some((ctx, configure)) => + val engine = ctx.createSSLEngine() + engine.setUseClientMode(false) + configure(engine) + + LeafBuilder( + if (isHttp2Enabled) http2Stage(executionContext, engine, wsKey) + else http1Stage(executionContext, secure = true, engine.some, wsKey) + ).prepend(new SSLStage(engine)) + + case None => + if (isHttp2Enabled) + logger.warn("HTTP/2 support requires TLS. Falling back to HTTP/1.") + LeafBuilder(http1Stage(executionContext, secure = false, None, wsKey)) + } + } + } + } + } + + def resource: Resource[F, Server] = { + def resolveAddress(address: InetSocketAddress) = + if (address.isUnresolved) new InetSocketAddress(address.getHostName, address.getPort) + else address + + val mkFactory: Resource[F, ServerChannelGroup] = Resource.make(F.delay { + NIO1SocketServerGroup + .fixed( + workerThreads = connectorPoolSize, + bufferSize = bufferSize, + channelOptions = channelOptions, + selectorThreadFactory = selectorThreadFactory, + maxConnections = maxConnections, + ) + })(factory => F.delay(factory.closeGroup())) + + def mkServerChannel( + factory: ServerChannelGroup, + scheduler: TickWheelExecutor, + dispatcher: Dispatcher[F], + ): Resource[F, ServerChannel] = + Resource.make( + for { + ctxOpt <- sslConfig.makeContext + engineCfg = ctxOpt.map(ctx => (ctx, sslConfig.configureEngine _)) + address = resolveAddress(socketAddress) + } yield factory.bind(address, pipelineFactory(scheduler, engineCfg, dispatcher)).get + )(serverChannel => F.delay(serverChannel.close())) + + def logStart(server: Server): Resource[F, Unit] = + Resource.eval(F.delay { + Option(banner) + .filter(_.nonEmpty) + .map(_.mkString("\n", "\n", "")) + .foreach(logger.info(_)) + + logger.info( + s"http4s v${Http4sBuildInfo.version} on blaze v${BlazeBuildInfo.version} started at ${server.baseUri}" + ) + }) + + for { + // blaze doesn't have graceful shutdowns, which means it may continue to submit effects, + // ever after the server has acknowledged shutdown, so we just need to allocate + dispatcher <- Resource.eval(Dispatcher[F].allocated.map(_._1)) + scheduler <- tickWheelResource + + _ <- Resource.eval(verifyTimeoutRelations()) + + factory <- mkFactory + serverChannel <- mkServerChannel(factory, scheduler, dispatcher) + server = new Server { + val address: InetSocketAddress = + serverChannel.socketAddress + + val isSecure = sslConfig.isSecure + + override def toString: String = + s"BlazeServer($address)" + } + + _ <- logStart(server) + } yield server + } + + private def verifyTimeoutRelations(): F[Unit] = + F.delay { + if (responseHeaderTimeout.isFinite && responseHeaderTimeout >= idleTimeout) + logger.warn( + s"responseHeaderTimeout ($responseHeaderTimeout) is >= idleTimeout ($idleTimeout). " + + s"It is recommended to configure responseHeaderTimeout < idleTimeout, " + + s"otherwise timeout responses won't be delivered to clients." + ) + } +} + +object BlazeServerBuilder { + @deprecated( + "Most users should use the default execution context provided. " + + "If you have a specific reason to use a custom one, use `.withExecutionContext`", + "0.23.5", + ) + def apply[F[_]](executionContext: ExecutionContext)(implicit F: Async[F]): BlazeServerBuilder[F] = + apply[F].withExecutionContext(executionContext) + + def apply[F[_]](implicit F: Async[F]): BlazeServerBuilder[F] = + new BlazeServerBuilder( + socketAddress = defaults.IPv4SocketAddress, + executionContextConfig = ExecutionContextConfig.DefaultContext, + responseHeaderTimeout = defaults.ResponseTimeout, + idleTimeout = defaults.IdleTimeout, + connectorPoolSize = DefaultPoolSize, + bufferSize = 64 * 1024, + selectorThreadFactory = defaultThreadSelectorFactory, + sslConfig = new NoSsl[F](), + isHttp2Enabled = false, + maxRequestLineLen = 4 * 1024, + maxHeadersLen = defaults.MaxHeadersSize, + chunkBufferMaxSize = 1024 * 1024, + httpApp = _ => defaultApp[F], + serviceErrorHandler = DefaultServiceErrorHandler[F], + banner = defaults.Banner, + maxConnections = defaults.MaxConnections, + channelOptions = ChannelOptions(Vector.empty), + maxWebSocketBufferSize = None, + ) + + private def defaultApp[F[_]: Applicative]: HttpApp[F] = + Kleisli(_ => Response[F](Status.NotFound).pure[F]) + + private def defaultThreadSelectorFactory: ThreadFactory = + threadFactory(name = n => s"blaze-selector-${n}", daemon = false) + + private sealed trait SslConfig[F[_]] { + def makeContext: F[Option[SSLContext]] + def configureEngine(sslEngine: SSLEngine): Unit + def isSecure: Boolean + } + + private final class KeyStoreBits[F[_]]( + keyStore: StoreInfo, + keyManagerPassword: String, + protocol: String, + trustStore: Option[StoreInfo], + clientAuth: SSLClientAuthMode, + )(implicit F: Sync[F]) + extends SslConfig[F] { + def makeContext: F[Option[SSLContext]] = + F.delay { + val ksStream = new FileInputStream(keyStore.path) + val ks = KeyStore.getInstance("JKS") + ks.load(ksStream, keyStore.password.toCharArray) + ksStream.close() + + val tmf = trustStore.map { auth => + val ksStream = new FileInputStream(auth.path) + + val ks = KeyStore.getInstance("JKS") + ks.load(ksStream, auth.password.toCharArray) + ksStream.close() + + val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) + + tmf.init(ks) + tmf.getTrustManagers + } + + val kmf = KeyManagerFactory.getInstance( + Option(Security.getProperty("ssl.KeyManagerFactory.algorithm")) + .getOrElse(KeyManagerFactory.getDefaultAlgorithm) + ) + + kmf.init(ks, keyManagerPassword.toCharArray) + + val context = SSLContext.getInstance(protocol) + context.init(kmf.getKeyManagers, tmf.orNull, null) + context.some + } + def configureEngine(engine: SSLEngine): Unit = + configureEngineFromSslClientAuthMode(engine, clientAuth) + def isSecure: Boolean = true + } + + private class ContextOnly[F[_]](sslContext: SSLContext)(implicit F: Applicative[F]) + extends SslConfig[F] { + def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + def configureEngine(engine: SSLEngine): Unit = () + def isSecure: Boolean = true + } + + private class ContextWithParameters[F[_]](sslContext: SSLContext, sslParameters: SSLParameters)( + implicit F: Applicative[F] + ) extends SslConfig[F] { + def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + def configureEngine(engine: SSLEngine): Unit = engine.setSSLParameters(sslParameters) + def isSecure: Boolean = true + } + + private class ContextWithClientAuth[F[_]](sslContext: SSLContext, clientAuth: SSLClientAuthMode)( + implicit F: Applicative[F] + ) extends SslConfig[F] { + def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + def configureEngine(engine: SSLEngine): Unit = + configureEngineFromSslClientAuthMode(engine, clientAuth) + def isSecure: Boolean = true + } + + private class NoSsl[F[_]]()(implicit F: Applicative[F]) extends SslConfig[F] { + def makeContext: F[Option[SSLContext]] = F.pure(None) + def configureEngine(engine: SSLEngine): Unit = () + def isSecure: Boolean = false + } + + private def configureEngineFromSslClientAuthMode( + engine: SSLEngine, + clientAuthMode: SSLClientAuthMode, + ): Unit = + clientAuthMode match { + case SSLClientAuthMode.Required => engine.setNeedClientAuth(true) + case SSLClientAuthMode.Requested => engine.setWantClientAuth(true) + case SSLClientAuthMode.NotRequested => () + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerParser.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerParser.scala new file mode 100644 index 000000000..9a1dda2c0 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerParser.scala @@ -0,0 +1,110 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze.server + +import cats.effect._ +import cats.syntax.all._ +import org.log4s.Logger +import org.typelevel.vault._ + +import java.nio.ByteBuffer +import scala.collection.mutable.ListBuffer +import scala.util.Either + +private[http4s] final class Http1ServerParser[F[_]]( + logger: Logger, + maxRequestLine: Int, + maxHeadersLen: Int, +)(implicit F: Async[F]) + extends blaze.http.parser.Http1ServerParser(maxRequestLine, maxHeadersLen, 2 * 1024) { + private var uri: String = _ + private var method: String = _ + private var minor: Int = -1 + private val headers = new ListBuffer[Header.ToRaw] + + def minorVersion(): Int = minor + + def doParseRequestLine(buff: ByteBuffer): Boolean = parseRequestLine(buff) + + def doParseHeaders(buff: ByteBuffer): Boolean = parseHeaders(buff) + + def doParseContent(buff: ByteBuffer): Option[ByteBuffer] = Option(parseContent(buff)) + + def collectMessage( + body: EntityBody[F], + attrs: Vault, + ): Either[(ParseFailure, HttpVersion), Request[F]] = { + val h = Headers(headers.result()) + headers.clear() + val protocol = if (minorVersion() == 1) HttpVersion.`HTTP/1.1` else HttpVersion.`HTTP/1.0` + + val attrsWithTrailers = + if (minorVersion() == 1 && isChunked) + attrs.insert( + Message.Keys.TrailerHeaders[F], + F.defer[Headers] { + if (!contentComplete()) + F.raiseError( + new IllegalStateException( + "Attempted to collect trailers before the body was complete." + ) + ) + else F.pure(Headers(headers.result())) + }, + ) + else attrs // Won't have trailers without a chunked body + + Method + .fromString(this.method) + .flatMap { method => + Uri.requestTarget(this.uri).map { uri => + Request(method, uri, protocol, h, body, attrsWithTrailers) + } + } + .leftMap(_ -> protocol) + } + + override def submitRequestLine( + methodString: String, + uri: String, + scheme: String, + majorversion: Int, + minorversion: Int, + ): Boolean = { + logger.trace(s"Received request($methodString $uri $scheme/$majorversion.$minorversion)") + this.uri = uri + this.method = methodString + this.minor = minorversion + false + } + + // ///////////////// Stateful methods for the HTTP parser /////////////////// + override protected def headerComplete(name: String, value: String): Boolean = { + logger.trace(s"Received header '$name: $value'") + headers += name -> value + false + } + + override def reset(): Unit = { + uri = null + method = null + minor = -1 + headers.clear() + super.reset() + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala new file mode 100644 index 000000000..700fe8a0d --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala @@ -0,0 +1,393 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.effect.Async +import cats.effect.std.Dispatcher +import cats.effect.syntax.monadCancel._ +import cats.syntax.all._ +import org.http4s.blaze.http.parser.BaseExceptions.BadMessage +import org.http4s.blaze.http.parser.BaseExceptions.ParserException +import org.http4s.blaze.pipeline.Command.EOF +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.pipeline.{Command => Cmd} +import org.http4s.blaze.util.BufferTools +import org.http4s.blaze.util.BufferTools.emptyBuffer +import org.http4s.blaze.util.Execution._ +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.Http1Stage +import org.http4s.blazecore.IdleTimeoutStage +import org.http4s.blazecore.util.BodylessWriter +import org.http4s.blazecore.util.Http1Writer +import org.http4s.headers.Connection +import org.http4s.headers.`Content-Length` +import org.http4s.headers.`Transfer-Encoding` +import org.http4s.server.ServiceErrorHandler +import org.http4s.util.StringWriter +import org.http4s.websocket.WebSocketContext +import org.typelevel.vault._ + +import java.nio.ByteBuffer +import java.util.concurrent.TimeoutException +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.concurrent.duration.FiniteDuration +import scala.util.Either +import scala.util.Failure +import scala.util.Left +import scala.util.Right +import scala.util.Success +import scala.util.Try + +private[http4s] object Http1ServerStage { + def apply[F[_]]( + routes: HttpApp[F], + attributes: () => Vault, + executionContext: ExecutionContext, + wsKey: Key[WebSocketContext[F]], + maxRequestLineLen: Int, + maxHeadersLen: Int, + chunkBufferMaxSize: Int, + serviceErrorHandler: ServiceErrorHandler[F], + responseHeaderTimeout: Duration, + idleTimeout: Duration, + scheduler: TickWheelExecutor, + dispatcher: Dispatcher[F], + maxWebSocketBufferSize: Option[Int], + )(implicit F: Async[F]): Http1ServerStage[F] = + new Http1ServerStage( + routes, + attributes, + executionContext, + maxRequestLineLen, + maxHeadersLen, + chunkBufferMaxSize, + serviceErrorHandler, + responseHeaderTimeout, + idleTimeout, + scheduler, + dispatcher, + ) with WebSocketSupport[F] { + val webSocketKey = wsKey + override protected def maxBufferSize: Option[Int] = maxWebSocketBufferSize + } +} + +private[blaze] class Http1ServerStage[F[_]]( + httpApp: HttpApp[F], + requestAttrs: () => Vault, + implicit protected val executionContext: ExecutionContext, + maxRequestLineLen: Int, + maxHeadersLen: Int, + override val chunkBufferMaxSize: Int, + serviceErrorHandler: ServiceErrorHandler[F], + responseHeaderTimeout: Duration, + idleTimeout: Duration, + scheduler: TickWheelExecutor, + val dispatcher: Dispatcher[F], +)(implicit protected val F: Async[F]) + extends Http1Stage[F] + with TailStage[ByteBuffer] { + // micro-optimization: unwrap the routes and call its .run directly + private[this] val runApp = httpApp.run + + // protected by synchronization on `parser` + private[this] val parser = new Http1ServerParser[F](logger, maxRequestLineLen, maxHeadersLen) + private[this] var isClosed = false + @volatile private[this] var cancelToken: Option[() => Future[Unit]] = None + + val name = "Http4sServerStage" + + logger.trace(s"Http4sStage starting up") + + override protected final def doParseContent(buffer: ByteBuffer): Option[ByteBuffer] = + parser.synchronized { + parser.doParseContent(buffer) + } + + override protected final def contentComplete(): Boolean = + parser.synchronized { + parser.contentComplete() + } + + // Will act as our loop + override def stageStartup(): Unit = { + logger.debug("Starting HTTP pipeline") + initIdleTimeout() + requestLoop() + } + + private def initIdleTimeout(): Unit = + idleTimeout match { + case f: FiniteDuration => + val cb: Callback[TimeoutException] = { + case Left(t) => + fatalError(t, "Error in idle timeout callback") + case Right(_) => + logger.debug("Shutting down due to idle timeout") + closePipeline(None) + } + val stage = new IdleTimeoutStage[ByteBuffer](f, scheduler, executionContext) + spliceBefore(stage) + stage.init(cb) + case _ => + } + + private val handleReqRead: Try[ByteBuffer] => Unit = { + case Success(buff) => reqLoopCallback(buff) + case Failure(Cmd.EOF) => closeConnection() + case Failure(t) => fatalError(t, "Error in requestLoop()") + } + + private def requestLoop(): Unit = channelRead().onComplete(handleReqRead)(trampoline) + + private def reqLoopCallback(buff: ByteBuffer): Unit = { + logRequest(buff) + parser.synchronized { + if (!isClosed) + try + if (!parser.requestLineComplete() && !parser.doParseRequestLine(buff)) + requestLoop() + else if (!parser.headersComplete() && !parser.doParseHeaders(buff)) + requestLoop() + else + // we have enough to start the request + runRequest(buff) + catch { + case t: BadMessage => + badMessage("Error parsing status or headers in requestLoop()", t, Request[F]()) + case t: Throwable => + internalServerError( + "error in requestLoop()", + t, + Request[F](), + () => Future.successful(emptyBuffer), + ) + } + } + } + + private def logRequest(buffer: ByteBuffer): Unit = + logger.trace { + val msg = BufferTools + .bufferToString(buffer.duplicate()) + .replace("\r", "\\r") + .replace("\n", "\\n\n") + s"Received Request:\n$msg" + } + + // Only called while holding the monitor of `parser` + private def runRequest(buffer: ByteBuffer): Unit = { + val (body, cleanup) = collectBodyFromParser( + buffer, + () => Either.left(InvalidBodyException("Received premature EOF.")), + ) + + parser.collectMessage(body, requestAttrs()) match { + case Right(req) => + executionContext.execute(new Runnable { + def run(): Unit = { + val action = raceTimeout(req) + .recoverWith(serviceErrorHandler(req)) + .flatMap(resp => F.delay(renderResponse(req, resp, cleanup))) + .attempt + .flatMap { + case Right(_) => F.unit + case Left(t) => + F.delay(logger.error(t)(s"Error running request: $req")) + .guarantee( + F.delay { + cancelToken = None + closeConnection() + } + ) + } + + cancelToken = Some(dispatcher.unsafeToFutureCancelable(action)._2) + } + }) + case Left((e, protocol)) => + badMessage(e.details, new BadMessage(e.sanitized), Request[F]().withHttpVersion(protocol)) + } + } + + protected def renderResponse( + req: Request[F], + resp: Response[F], + bodyCleanup: () => Future[ByteBuffer], + ): Unit = { + val rr = new StringWriter(512) + rr << req.httpVersion << ' ' << resp.status << "\r\n" + + Http1Stage.encodeHeaders(resp.headers.headers, rr, isServer = true) + + val respTransferCoding = resp.headers.get[`Transfer-Encoding`] + val lengthHeader = resp.headers.get[`Content-Length`] + val respConn = resp.headers.get[Connection] + + // Need to decide which encoder and if to close on finish + val closeOnFinish = respConn + .map(_.hasClose) + .orElse { + req.headers.get[Connection].map(checkCloseConnection(_, rr)) + } + .getOrElse( + parser.minorVersion() == 0 + ) // Finally, if nobody specifies, http 1.0 defaults to close + + // choose a body encoder. Will add a Transfer-Encoding header if necessary + val bodyEncoder: Http1Writer[F] = + if (req.method == Method.HEAD || !resp.status.isEntityAllowed) { + // We don't have a body (or don't want to send it) so we just get the headers + + if ( + !resp.status.isEntityAllowed && + (lengthHeader.isDefined || respTransferCoding.isDefined) + ) + logger.warn( + s"Body detected for response code ${resp.status.code} which doesn't permit an entity. Dropping." + ) + + if (req.method == Method.HEAD) + // write message body header for HEAD response + (parser.minorVersion(), respTransferCoding, lengthHeader) match { + case (minor, Some(enc), _) if minor > 0 && enc.hasChunked => + rr << "Transfer-Encoding: chunked\r\n" + case (_, _, Some(len)) => rr << len << "\r\n" + case _ => // nop + } + + // add KeepAlive to Http 1.0 responses if the header isn't already present + rr << (if (!closeOnFinish && parser.minorVersion() == 0 && respConn.isEmpty) + "Connection: keep-alive\r\n\r\n" + else "\r\n") + + new BodylessWriter[F](this, closeOnFinish) + } else + getEncoder( + respConn, + respTransferCoding, + lengthHeader, + resp.trailerHeaders, + rr, + parser.minorVersion(), + closeOnFinish, + false, + ) + + // TODO: pool shifting: https://github.com/http4s/http4s/blob/main/core/src/main/scala/org/http4s/internal/package.scala#L45 + val fa = bodyEncoder + .write(rr, resp.body) + .recover { case EOF => true } + .attempt + .flatMap { + case Right(requireClose) => + if (closeOnFinish || requireClose) { + logger.trace("Request/route requested closing connection.") + F.delay(closeConnection()) + } else + F.delay { + bodyCleanup().onComplete { + case s @ Success(_) => // Serve another request + parser.reset() + handleReqRead(s) + + case Failure(EOF) => closeConnection() + + case Failure(t) => fatalError(t, "Failure in body cleanup") + }(trampoline) + } + case Left(t) => + logger.error(t)("Error writing body") + F.delay(closeConnection()) + } + + dispatcher.unsafeRunAndForget(fa) + + () + } + + private def closeConnection(): Unit = { + logger.debug("closeConnection()") + stageShutdown() + closePipeline(None) + } + + override protected def stageShutdown(): Unit = { + logger.debug("Shutting down HttpPipeline") + parser.synchronized { + cancel() + isClosed = true + parser.shutdownParser() + } + super.stageShutdown() + } + + private def cancel(): Unit = + cancelToken.foreach(_().onComplete { + case Success(_) => + () + case Failure(t) => + logger.warn(t)(s"Error canceling request. No request details are available.") + }) + + protected final def badMessage( + debugMessage: String, + t: ParserException, + req: Request[F], + ): Unit = { + logger.debug(t)(s"Bad Request: $debugMessage") + val resp = Response[F](Status.BadRequest) + .withHeaders(Connection.close, `Content-Length`.zero) + renderResponse(req, resp, () => Future.successful(emptyBuffer)) + } + + // The error handler of last resort + protected final def internalServerError( + errorMsg: String, + t: Throwable, + req: Request[F], + bodyCleanup: () => Future[ByteBuffer], + ): Unit = { + logger.error(t)(errorMsg) + val resp = Response[F](Status.InternalServerError) + .withHeaders(Connection.close, `Content-Length`.zero) + renderResponse( + req, + resp, + bodyCleanup, + ) // will terminate the connection due to connection: close header + } + + private[this] val raceTimeout: Request[F] => F[Response[F]] = + responseHeaderTimeout match { + case finite: FiniteDuration => + val timeoutResponse = F.async[Response[F]] { cb => + F.delay { + val cancellable = + scheduler.schedule(() => cb(Right(Response.timeout[F])), executionContext, finite) + Some(F.delay(cancellable.cancel())) + } + } + req => F.race(runApp(req), timeoutResponse).map(_.merge) + case _ => + runApp + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/Http2NodeStage.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/Http2NodeStage.scala new file mode 100644 index 000000000..a55a621a9 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/Http2NodeStage.scala @@ -0,0 +1,292 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.effect.Async +import cats.effect.std.Dispatcher +import cats.effect.syntax.temporal._ +import cats.syntax.all._ +import fs2.Stream._ +import fs2._ +import org.http4s.blaze.http.HeaderNames +import org.http4s.blaze.http.Headers +import org.http4s.blaze.http.http2._ +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.pipeline.{Command => Cmd} +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.IdleTimeoutStage +import org.http4s.blazecore.util.End +import org.http4s.blazecore.util.Http2Writer +import org.http4s.server.ServiceErrorHandler +import org.http4s.{Method => HMethod} +import org.typelevel.vault._ + +import java.util.Locale +import java.util.concurrent.TimeoutException +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.Duration +import scala.concurrent.duration.FiniteDuration +import scala.util._ + +private class Http2NodeStage[F[_]]( + streamId: Int, + timeout: Duration, + implicit private val executionContext: ExecutionContext, + attributes: () => Vault, + httpApp: HttpApp[F], + serviceErrorHandler: ServiceErrorHandler[F], + responseHeaderTimeout: Duration, + idleTimeout: Duration, + scheduler: TickWheelExecutor, + dispatcher: Dispatcher[F], +)(implicit F: Async[F]) + extends TailStage[StreamFrame] { + // micro-optimization: unwrap the service and call its .run directly + private[this] val runApp = httpApp.run + + override def name = "Http2NodeStage" + + override protected def stageStartup(): Unit = { + super.stageStartup() + initIdleTimeout() + readHeaders() + } + + private def initIdleTimeout(): Unit = + idleTimeout match { + case f: FiniteDuration => + val cb: Callback[TimeoutException] = { + case Left(t) => + logger.error(t)("Error in idle timeout callback") + closePipeline(Some(t)) + case Right(_) => + logger.debug("Shutting down due to idle timeout") + closePipeline(None) + } + val stage = new IdleTimeoutStage[StreamFrame](f, scheduler, executionContext) + spliceBefore(stage) + stage.init(cb) + case _ => + } + + private def readHeaders(): Unit = + channelRead(timeout = timeout).onComplete { + case Success(HeadersFrame(_, endStream, hs)) => + checkAndRunRequest(hs, endStream) + + case Success(frame) => + val e = Http2Exception.PROTOCOL_ERROR.rst(streamId, s"Received invalid frame: $frame") + closePipeline(Some(e)) + + case Failure(Cmd.EOF) => + closePipeline(None) + + case Failure(t) => + logger.error(t)("Unknown error in readHeaders") + val e = Http2Exception.INTERNAL_ERROR.rst(streamId, s"Unknown error") + closePipeline(Some(e)) + } + + /** collect the body: a maxlen < 0 is interpreted as undefined */ + private def getBody(maxlen: Long): EntityBody[F] = { + var complete = false + var bytesRead = 0L + + val t = F.async[Option[Chunk[Byte]]] { cb => + F.delay { + if (complete) cb(End) + else + channelRead(timeout = timeout).onComplete { + case Success(DataFrame(last, bytes)) => + complete = last + bytesRead += bytes.remaining() + + // Check length: invalid length is a stream error of type PROTOCOL_ERROR + // https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-http2-17#section-8.1.2 -> 8.2.1.6 + if (complete && maxlen > 0 && bytesRead != maxlen) { + val msg = s"Entity too small. Expected $maxlen, received $bytesRead" + val e = Http2Exception.PROTOCOL_ERROR.rst(streamId, msg) + closePipeline(Some(e)) + cb(Either.left(InvalidBodyException(msg))) + } else if (maxlen > 0 && bytesRead > maxlen) { + val msg = s"Entity too large. Expected $maxlen, received bytesRead" + val e = Http2Exception.PROTOCOL_ERROR.rst(streamId, msg) + closePipeline(Some(e)) + cb(Either.left(InvalidBodyException(msg))) + } else cb(Either.right(Some(Chunk.array(bytes.array)))) + + case Success(HeadersFrame(_, true, ts)) => + logger.warn("Discarding trailers: " + ts) + cb(Either.right(Some(Chunk.empty))) + + case Success(other) => // This should cover it + val msg = "Received invalid frame while accumulating body: " + other + logger.info(msg) + val e = Http2Exception.PROTOCOL_ERROR.rst(streamId, msg) + closePipeline(Some(e)) + cb(Either.left(InvalidBodyException(msg))) + + case Failure(Cmd.EOF) => + logger.debug("EOF while accumulating body") + cb(Either.left(InvalidBodyException("Received premature EOF."))) + closePipeline(None) + + case Failure(t) => + logger.error(t)("Error in getBody().") + val e = Http2Exception.INTERNAL_ERROR.rst(streamId, "Failed to read body") + cb(Either.left(e)) + closePipeline(Some(e)) + } + + None + } + } + + repeatEval(t).unNoneTerminate.flatMap(chunk(_)) + } + + private def checkAndRunRequest(hs: Headers, endStream: Boolean): Unit = { + val headers = new ListBuffer[Header.ToRaw] + var method: HMethod = null + var scheme: String = null + var path: Uri = null + var contentLength: Long = -1 + var error: String = "" + var pseudoDone = false + + hs.foreach { + case (PseudoHeaders.Method, v) => + if (pseudoDone) error += "Pseudo header in invalid position. " + else if (method == null) org.http4s.Method.fromString(v) match { + case Right(m) => method = m + case Left(e) => error = s"$error Invalid method: $e " + } + else error += "Multiple ':method' headers defined. " + + case (PseudoHeaders.Scheme, v) => + if (pseudoDone) error += "Pseudo header in invalid position. " + else if (scheme == null) scheme = v + else error += "Multiple ':scheme' headers defined. " + + case (PseudoHeaders.Path, v) => + if (pseudoDone) error += "Pseudo header in invalid position. " + else if (path == null) Uri.requestTarget(v) match { + case Right(p) => path = p + case Left(e) => error = s"$error Invalid path: $e" + } + else error += "Multiple ':path' headers defined. " + + case (PseudoHeaders.Authority, _) => // NOOP; TODO: we should keep the authority header + if (pseudoDone) error += "Pseudo header in invalid position. " + + case h @ (k, _) if k.startsWith(":") => error += s"Invalid pseudo header: $h. " + case (k, _) if !HeaderNames.validH2HeaderKey(k) => error += s"Invalid header key: $k. " + + case hs => // Non pseudo headers + pseudoDone = true + hs match { + case h @ (HeaderNames.Connection, _) => + error += s"HTTP/2 forbids connection specific headers: $h. " + + case (HeaderNames.ContentLength, v) => + if (contentLength < 0) try { + val sz = java.lang.Long.parseLong(v) + if (sz != 0 && endStream) error += s"Nonzero content length ($sz) for end of stream." + else if (sz < 0) error += s"Negative content length: $sz" + else contentLength = sz + } catch { + case _: NumberFormatException => error += s"Invalid content-length: $v. " + } + else + error += "Received multiple content-length headers" + + case (HeaderNames.TE, v) => + if (!v.equalsIgnoreCase("trailers")) + error += s"HTTP/2 forbids TE header values other than 'trailers'. " + // ignore otherwise + + case (k, v) => headers += k -> v + } + } + + if (method == null || scheme == null || path == null) + error += s"Invalid request: missing pseudo headers. Method: $method, Scheme: $scheme, path: $path. " + + if (error.length > 0) + closePipeline(Some(Http2Exception.PROTOCOL_ERROR.rst(streamId, error))) + else { + val body = if (endStream) EmptyBody else getBody(contentLength) + val hs = Headers(headers.result()) + val req = Request(method, path, HttpVersion.`HTTP/2`, hs, body, attributes()) + executionContext.execute(new Runnable { + def run(): Unit = { + val action = F + .defer(raceTimeout(req)) + .recoverWith(serviceErrorHandler(req)) + .flatMap(renderResponse(_)) + + val fa = action.attempt.flatMap { + case Right(_) => F.unit + case Left(t) => + F.delay(logger.error(t)(s"Error running request: $req")).attempt *> F.delay( + closePipeline(None) + ) + } + + dispatcher.unsafeRunSync(fa) + + () + } + }) + } + } + + private def renderResponse(resp: Response[F]): F[Unit] = { + val hs = new ArrayBuffer[(String, String)](16) + hs += PseudoHeaders.Status -> Integer.toString(resp.status.code) + resp.headers.foreach { h => + // Connection related headers must be removed from the message because + // this information is conveyed by other means. + // http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2 + if ( + h.name != headers.`Transfer-Encoding`.name && + h.name != Header[headers.Connection].name + ) { + hs += ((h.name.toString.toLowerCase(Locale.ROOT), h.value)) + () + } + } + + new Http2Writer(this, hs).writeEntityBody(resp.body).attempt.map { + case Right(_) => closePipeline(None) + case Left(Cmd.EOF) => stageShutdown() + case Left(t) => closePipeline(Some(t)) + } + } + + private[this] val raceTimeout: Request[F] => F[Response[F]] = + responseHeaderTimeout match { + case finite: FiniteDuration => + req => runApp(req).timeoutTo(finite, F.pure(Response.timeout[F])) + case _ => + runApp + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala new file mode 100644 index 000000000..5ebe21884 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.effect.Async +import cats.effect.std.Dispatcher +import org.http4s.blaze.http.http2.DefaultFlowStrategy +import org.http4s.blaze.http.http2.Http2Settings +import org.http4s.blaze.http.http2.server.ALPNServerSelector +import org.http4s.blaze.http.http2.server.ServerPriorKnowledgeHandshaker +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blaze.pipeline.TailStage +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.server.ServiceErrorHandler +import org.http4s.websocket.WebSocketContext +import org.typelevel.vault._ + +import java.nio.ByteBuffer +import javax.net.ssl.SSLEngine +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.Duration + +/** Facilitates the use of ALPN when using blaze http2 support */ +private[http4s] object ProtocolSelector { + def apply[F[_]]( + engine: SSLEngine, + httpApp: HttpApp[F], + maxRequestLineLen: Int, + maxHeadersLen: Int, + chunkBufferMaxSize: Int, + requestAttributes: () => Vault, + executionContext: ExecutionContext, + serviceErrorHandler: ServiceErrorHandler[F], + responseHeaderTimeout: Duration, + idleTimeout: Duration, + scheduler: TickWheelExecutor, + dispatcher: Dispatcher[F], + webSocketKey: Key[WebSocketContext[F]], + maxWebSocketBufferSize: Option[Int], + )(implicit F: Async[F]): ALPNServerSelector = { + def http2Stage(): TailStage[ByteBuffer] = { + val newNode = { (streamId: Int) => + LeafBuilder( + new Http2NodeStage( + streamId, + Duration.Inf, + executionContext, + requestAttributes, + httpApp, + serviceErrorHandler, + responseHeaderTimeout, + idleTimeout, + scheduler, + dispatcher, + ) + ) + } + + val localSettings = + Http2Settings.default.copy( + maxConcurrentStreams = 100, // TODO: configurable? + maxHeaderListSize = maxHeadersLen, + ) + + new ServerPriorKnowledgeHandshaker( + localSettings = localSettings, + flowStrategy = new DefaultFlowStrategy(localSettings), + nodeBuilder = newNode, + ) + } + + def http1Stage(): TailStage[ByteBuffer] = + Http1ServerStage[F]( + httpApp, + requestAttributes, + executionContext, + wsKey = webSocketKey, + maxRequestLineLen, + maxHeadersLen, + chunkBufferMaxSize, + serviceErrorHandler, + responseHeaderTimeout, + idleTimeout, + scheduler, + dispatcher, + maxWebSocketBufferSize, + ) + + def preference(protos: Set[String]): String = + protos + .find { + case "h2" | "h2-14" | "h2-15" => true + case _ => false + } + .getOrElse("undefined") + + def select(s: String): LeafBuilder[ByteBuffer] = + LeafBuilder(s match { + case "h2" | "h2-14" | "h2-15" => http2Stage() + case _ => http1Stage() + }) + + new ALPNServerSelector(engine, preference, select) + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/SSLContextFactory.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/SSLContextFactory.scala new file mode 100644 index 000000000..d313a8e70 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/SSLContextFactory.scala @@ -0,0 +1,29 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +import java.security.cert.X509Certificate +import javax.net.ssl.SSLSession + +@deprecated("Moved to org.http4s.internal.tls", "0.21.19") +private[http4s] object SSLContextFactory { + def getCertChain(sslSession: SSLSession): List[X509Certificate] = + org.http4s.internal.tls.getCertChain(sslSession) + + def deduceKeyLength(cipherSuite: String): Int = + org.http4s.internal.tls.deduceKeyLength(cipherSuite) +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/WSFrameAggregator.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/WSFrameAggregator.scala new file mode 100644 index 000000000..9cec98613 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/WSFrameAggregator.scala @@ -0,0 +1,159 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +import org.http4s.blaze.pipeline.MidStage +import org.http4s.blaze.server.WSFrameAggregator.Accumulator +import org.http4s.blaze.util.Execution._ +import org.http4s.internal.bug +import org.http4s.websocket.WebSocketFrame +import org.http4s.websocket.WebSocketFrame._ +import scodec.bits.ByteVector + +import java.net.ProtocolException +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Success + +private class WSFrameAggregator extends MidStage[WebSocketFrame, WebSocketFrame] { + def name: String = "WebSocket Frame Aggregator" + + private[this] val accumulator = new Accumulator + + def readRequest(size: Int): Future[WebSocketFrame] = { + val p = Promise[WebSocketFrame]() + channelRead(size).onComplete { + case Success(f) => readLoop(f, p) + case Failure(t) => p.failure(t) + }(directec) + p.future + } + + private def readLoop(frame: WebSocketFrame, p: Promise[WebSocketFrame]): Unit = + frame match { + case _: Text => handleHead(frame, p) + case _: Binary => handleHead(frame, p) + + case c: Continuation => + if (accumulator.isEmpty) { + val e = new ProtocolException( + "Invalid state: Received a Continuation frame without accumulated state." + ) + logger.error(e)("Invalid state") + p.failure(e) + () + } else { + accumulator.append(frame) + if (c.last) { + // We are finished with the segment, accumulate + p.success(accumulator.take()) + () + } else + channelRead().onComplete { + case Success(f) => + readLoop(f, p) + case Failure(t) => + p.failure(t) + () + }(trampoline) + } + + case f => + // Must be a control frame, send it out + p.success(f) + () + } + + private def handleHead(frame: WebSocketFrame, p: Promise[WebSocketFrame]): Unit = + if (!accumulator.isEmpty) { + val e = new ProtocolException(s"Invalid state: Received a head frame with accumulated state") + accumulator.clear() + p.failure(e) + () + } else if (frame.last) { + // Head frame that is complete + p.success(frame) + () + } else { + // Need to start aggregating + accumulator.append(frame) + channelRead().onComplete { + case Success(f) => + readLoop(f, p) + case Failure(t) => + p.failure(t) + () + }(directec) + } + + // Just forward write requests + def writeRequest(data: WebSocketFrame): Future[Unit] = channelWrite(data) + override def writeRequest(data: collection.Seq[WebSocketFrame]): Future[Unit] = channelWrite(data) +} + +private object WSFrameAggregator { + private final class Accumulator { + private[this] val queue = new mutable.Queue[WebSocketFrame] + private[this] var size = 0 + + def isEmpty: Boolean = queue.isEmpty + + def append(frame: WebSocketFrame): Unit = { + // The first frame needs to not be a continuation + if (queue.isEmpty) frame match { + case _: Text | _: Binary => // nop + case f => + throw bug(s"Shouldn't get here. Wrong type: ${f.getClass.getName}") + } + size += frame.length + queue += frame + () + } + + def take(): WebSocketFrame = { + val isText = queue.head match { + case _: Text => true + case _: Binary => false + case f => + // shouldn't happen as it's guarded for in `append` + val e = bug(s"Shouldn't get here. Wrong type: ${f.getClass.getName}") + throw e + } + + var out = ByteVector.empty + @tailrec + def go(): Unit = + if (!queue.isEmpty) { + val frame = queue.dequeue().data + out ++= frame + go() + } + go() + + size = 0 + if (isText) Text(out) else Binary(out) + } + + def clear(): Unit = { + size = 0 + queue.clear() + } + } +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketDecoder.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketDecoder.scala new file mode 100644 index 000000000..78f6c8f36 --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketDecoder.scala @@ -0,0 +1,48 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +import org.http4s.blaze.pipeline.stages.ByteToObjectStage +import org.http4s.websocket.FrameTranscoder +import org.http4s.websocket.FrameTranscoder.TranscodeError +import org.http4s.websocket.WebSocketFrame + +import java.net.ProtocolException +import java.nio.ByteBuffer + +private class WebSocketDecoder(val maxBufferSize: Int = 0) // unbounded + extends FrameTranscoder(isClient = false) + with ByteToObjectStage[WebSocketFrame] { + + val name = "Websocket Decoder" + + /** Encode objects to buffers + * @param in object to decode + * @return sequence of ByteBuffers to pass to the head + */ + @throws[TranscodeError] + def messageToBuffer(in: WebSocketFrame): collection.Seq[ByteBuffer] = frameToBuffer(in) + + /** Method that decodes ByteBuffers to objects. None reflects not enough data to decode a message + * Any unused data in the ByteBuffer will be recycled and available for the next read + * @param in ByteBuffer of immediately available data + * @return optional message if enough data was available + */ + @throws[TranscodeError] + @throws[ProtocolException] + def bufferToMessage(in: ByteBuffer): Option[WebSocketFrame] = Option(bufferToFrame(in)) +} diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala new file mode 100644 index 000000000..19043b99c --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala @@ -0,0 +1,127 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +import cats.effect._ +import cats.effect.std.Dispatcher +import cats.effect.std.Semaphore +import cats.syntax.all._ +import fs2.concurrent.SignallingRef +import org.http4s._ +import org.http4s.blaze.pipeline.LeafBuilder +import org.http4s.blazecore.websocket.Http4sWSStage +import org.http4s.blazecore.websocket.WebSocketHandshake +import org.http4s.headers._ +import org.http4s.websocket.WebSocketContext +import org.typelevel.vault.Key + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets._ +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.Future +import scala.util.Failure +import scala.util.Success + +private[http4s] trait WebSocketSupport[F[_]] extends Http1ServerStage[F] { + implicit protected val F: Async[F] + + protected def webSocketKey: Key[WebSocketContext[F]] + + implicit val dispatcher: Dispatcher[F] + + protected def maxBufferSize: Option[Int] + + override protected def renderResponse( + req: Request[F], + resp: Response[F], + cleanup: () => Future[ByteBuffer], + ): Unit = { + val ws = resp.attributes.lookup(webSocketKey) + logger.debug(s"Websocket key: $ws\nRequest headers: " + req.headers) + + ws match { + case None => super.renderResponse(req, resp, cleanup) + case Some(wsContext) => + val hdrs = req.headers.headers.map(h => (h.name.toString, h.value)) + if (WebSocketHandshake.isWebSocketRequest(hdrs)) + WebSocketHandshake.serverHandshake(hdrs) match { + case Left((code, msg)) => + logger.info(s"Invalid handshake $code, $msg") + val fa = + wsContext.failureResponse + .map( + _.withHeaders( + Connection.close, + "Sec-WebSocket-Version" -> "13", + ) + ) + .attempt + .flatMap { + case Right(resp) => + F.delay(super.renderResponse(req, resp, cleanup)) + case Left(_) => + F.unit + } + + dispatcher.unsafeRunAndForget(fa) + + () + + case Right(hdrs) => // Successful handshake + val sb = new StringBuilder + sb.append("HTTP/1.1 101 Switching Protocols\r\n") + hdrs.foreach { case (k, v) => + sb.append(k).append(": ").append(v).append('\r').append('\n') + } + + wsContext.headers.foreach { hdr => + sb.append(hdr.name).append(": ").append(hdr.value).append('\r').append('\n') + () + } + + sb.append('\r').append('\n') + + // write the accept headers and reform the pipeline + channelWrite(ByteBuffer.wrap(sb.result().getBytes(ISO_8859_1))).onComplete { + case Success(_) => + logger.debug("Switching pipeline segments for websocket") + + val deadSignal = dispatcher.unsafeRunSync(SignallingRef[F, Boolean](false)) + val writeSemaphore = dispatcher.unsafeRunSync(Semaphore[F](1L)) + val sentClose = new AtomicBoolean(false) + val segment = + LeafBuilder( + new Http4sWSStage[F]( + wsContext.webSocket, + sentClose, + deadSignal, + writeSemaphore, + dispatcher, + ) + ) // TODO: there is a constructor + .prepend(new WSFrameAggregator) + .prepend(new WebSocketDecoder(maxBufferSize.getOrElse(0))) + + this.replaceTail(segment, true) + + case Failure(t) => fatalError(t, "Error writing Websocket upgrade response") + }(executionContext) + } + else super.renderResponse(req, resp, cleanup) + } + } +} diff --git a/blaze-server/src/main/scala/org/http4s/server/blaze/package.scala b/blaze-server/src/main/scala/org/http4s/server/blaze/package.scala new file mode 100644 index 000000000..9b6f2068a --- /dev/null +++ b/blaze-server/src/main/scala/org/http4s/server/blaze/package.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.server + +package object blaze { + @deprecated("use org.http4s.blaze.server.BlazeServerBuilder", "0.22") + type BlazeServerBuilder[F[_]] = org.http4s.blaze.server.BlazeServerBuilder[F] + + @deprecated("use org.http4s.blaze.server.BlazeServerBuilder", "0.22") + val BlazeServerBuilder = org.http4s.blaze.server.BlazeServerBuilder +} diff --git a/blaze-server/src/test/resources/keystore.jks b/blaze-server/src/test/resources/keystore.jks new file mode 100644 index 000000000..aae876cdc Binary files /dev/null and b/blaze-server/src/test/resources/keystore.jks differ diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/AutoClosableResource.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/AutoClosableResource.scala new file mode 100644 index 000000000..a4fff9bf6 --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/AutoClosableResource.scala @@ -0,0 +1,51 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +private[http4s] object AutoCloseableResource { + + // TODO: Consider using [[munit.CatsEffectFixtures]] or [[cats.effect.Resource.fromAutoCloseable]] instead + /** Performs an operation using a resource, and then releases the resource, + * even if the operation throws an exception. This method behaves similarly + * to Java's try-with-resources. + * Ported from the Scala's 2.13 [[scala.util.Using.resource]]. + * + * @param resource the resource + * @param body the operation to perform with the resource + * @tparam R the type of the resource + * @tparam A the return type of the operation + * @return the result of the operation, if neither the operation nor + * releasing the resource throws + */ + private[http4s] def resource[R <: AutoCloseable, A](resource: R)(body: R => A): A = { + if (resource == null) throw new NullPointerException("null resource") + + var toThrow: Throwable = null + + try body(resource) + catch { + case t: Throwable => + toThrow = t + null.asInstanceOf[A] + } finally + if (toThrow eq null) resource.close() + else { + try resource.close() + finally throw toThrow + } + } +} diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerMtlsSpec.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerMtlsSpec.scala new file mode 100644 index 000000000..8cd3da35b --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerMtlsSpec.scala @@ -0,0 +1,170 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.blaze.server + +import cats.effect.IO +import cats.effect.Resource +import fs2.io.net.tls.TLSParameters +import munit.CatsEffectSuite +import org.http4s.HttpApp +import org.http4s.dsl.io._ +import org.http4s.server.Server +import org.http4s.server.ServerRequestKeys + +import java.net.URL +import java.nio.charset.StandardCharsets +import java.security.KeyStore +import javax.net.ssl._ +import scala.concurrent.duration._ +import scala.io.Source +import scala.util.Try + +/** Test cases for mTLS support in blaze server + */ +class BlazeServerMtlsSpec extends CatsEffectSuite { + { + val hostnameVerifier: HostnameVerifier = new HostnameVerifier { + override def verify(s: String, sslSession: SSLSession): Boolean = true + } + + // For test cases, don't do any host name verification. Certificates are self-signed and not available to all hosts + HttpsURLConnection.setDefaultHostnameVerifier(hostnameVerifier) + } + + def builder: BlazeServerBuilder[IO] = + BlazeServerBuilder[IO] + .withResponseHeaderTimeout(1.second) + + val service: HttpApp[IO] = HttpApp { + case req @ GET -> Root / "dummy" => + val output = req.attributes + .lookup(ServerRequestKeys.SecureSession) + .flatten + .map { session => + assertNotEquals(session.sslSessionId, "") + assertNotEquals(session.cipherSuite, "") + assertNotEquals(session.keySize, 0) + + session.X509Certificate.head.getSubjectX500Principal.getName + } + .getOrElse("Invalid") + + Ok(output) + + case req @ GET -> Root / "noauth" => + req.attributes + .lookup(ServerRequestKeys.SecureSession) + .flatten + .foreach { session => + assertNotEquals(session.sslSessionId, "") + assertNotEquals(session.cipherSuite, "") + assertNotEquals(session.keySize, 0) + assertEquals(session.X509Certificate, Nil) + } + + Ok("success") + + case _ => NotFound() + } + + def serverR(sslParameters: SSLParameters): Resource[IO, Server] = + builder + .bindAny() + .withSslContextAndParameters(sslContext, sslParameters) + .withHttpApp(service) + .resource + + lazy val sslContext: SSLContext = { + val ks = KeyStore.getInstance("JKS") + ks.load(getClass.getResourceAsStream("/keystore.jks"), "password".toCharArray) + + val kmf = KeyManagerFactory.getInstance("SunX509") + kmf.init(ks, "password".toCharArray) + + val js = KeyStore.getInstance("JKS") + js.load(getClass.getResourceAsStream("/keystore.jks"), "password".toCharArray) + + val tmf = TrustManagerFactory.getInstance("SunX509") + tmf.init(js) + + val sc = SSLContext.getInstance("TLSv1.2") + sc.init(kmf.getKeyManagers, tmf.getTrustManagers, null) + + sc + } + + /** Used for no mTLS client. Required to trust self-signed certificate. + */ + lazy val noAuthClientContext: SSLContext = { + val js = KeyStore.getInstance("JKS") + js.load(getClass.getResourceAsStream("/keystore.jks"), "password".toCharArray) + + val tmf = TrustManagerFactory.getInstance("SunX509") + tmf.init(js) + + val sc = SSLContext.getInstance("TLSv1.2") + sc.init(null, tmf.getTrustManagers, null) + + sc + } + + def get(server: Server, path: String, clientAuth: Boolean = true): String = + ErrorReporting.silenceOutputStreams { + val url = new URL(s"https://localhost:${server.address.getPort}$path") + val conn = url.openConnection().asInstanceOf[HttpsURLConnection] + conn.setRequestMethod("GET") + + if (clientAuth) + conn.setSSLSocketFactory(sslContext.getSocketFactory) + else + conn.setSSLSocketFactory(noAuthClientContext.getSocketFactory) + + Try { + Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name).getLines().mkString + }.recover { case ex: Throwable => + ex.getMessage + }.toOption + .getOrElse("") + } + + private def blazeServer(sslParameters: SSLParameters) = + ResourceFixture(serverR(sslParameters)) + + /** Test "required" auth mode + */ + blazeServer(TLSParameters(needClientAuth = true).toSSLParameters) + .test("Server should send mTLS request correctly") { server => + assertEquals(get(server, "/dummy", true), "CN=Test,OU=Test,O=Test,L=CA,ST=CA,C=US") + } + blazeServer(TLSParameters(needClientAuth = true).toSSLParameters) + .test("Server should fail for invalid client auth") { server => + assertNotEquals(get(server, "/dummy", false), "CN=Test,OU=Test,O=Test,L=CA,ST=CA,C=US") + } + + /** Test "requested" auth mode + */ + blazeServer(TLSParameters(wantClientAuth = true).toSSLParameters) + .test("Server should send mTLS request correctly with optional auth") { server => + assertEquals(get(server, "/dummy", true), "CN=Test,OU=Test,O=Test,L=CA,ST=CA,C=US") + } + + blazeServer(TLSParameters(wantClientAuth = true).toSSLParameters) + .test("Server should send mTLS request correctly without clientAuth") { server => + assertEquals(get(server, "/noauth", false), "success") + } + +} diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerSuite.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerSuite.scala new file mode 100644 index 000000000..8e8c1bace --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/BlazeServerSuite.scala @@ -0,0 +1,260 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.effect._ +import cats.effect.unsafe.IORuntime +import cats.effect.unsafe.IORuntimeConfig +import cats.effect.unsafe.Scheduler +import cats.syntax.all._ +import munit.CatsEffectSuite +import munit.TestOptions +import org.http4s.blaze.channel.ChannelOptions +import org.http4s.dsl.io._ +import org.http4s.internal.threads._ +import org.http4s.multipart.Multipart +import org.http4s.server.Server + +import java.net.HttpURLConnection +import java.net.URL +import java.nio.charset.StandardCharsets +import java.util.concurrent.ScheduledExecutorService +import java.util.concurrent.ScheduledThreadPoolExecutor +import java.util.concurrent.TimeUnit +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ +import scala.io.Source + +class BlazeServerSuite extends CatsEffectSuite { + + override implicit lazy val munitIoRuntime: IORuntime = { + val TestScheduler: ScheduledExecutorService = { + val s = + new ScheduledThreadPoolExecutor( + 2, + threadFactory(i => s"blaze-server-suite-scheduler-$i", true), + ) + s.setKeepAliveTime(10L, TimeUnit.SECONDS) + s.allowCoreThreadTimeOut(true) + s + } + + val blockingPool = newBlockingPool("blaze-server-suite-blocking") + val computePool = newDaemonPool("blaze-server-suite-compute", timeout = true) + val scheduledExecutor = TestScheduler + IORuntime.apply( + ExecutionContext.fromExecutor(computePool), + ExecutionContext.fromExecutor(blockingPool), + Scheduler.fromScheduledExecutor(scheduledExecutor), + () => { + blockingPool.shutdown() + computePool.shutdown() + scheduledExecutor.shutdown() + }, + IORuntimeConfig(), + ) + } + + override def afterAll(): Unit = munitIoRuntime.shutdown() + + private def builder = + BlazeServerBuilder[IO] + .withResponseHeaderTimeout(1.second) + + private val service: HttpApp[IO] = HttpApp { + case GET -> Root / "thread" / "routing" => + val thread = Thread.currentThread.getName + Ok(thread) + + case GET -> Root / "thread" / "effect" => + IO(Thread.currentThread.getName).flatMap(Ok(_)) + + case req @ POST -> Root / "echo" => + Ok(req.body) + + case _ -> Root / "never" => + IO.never + + case req @ POST -> Root / "issue2610" => + req.decode[Multipart[IO]] { mp => + Ok(mp.parts.foldMap(_.body)) + } + + case _ => NotFound() + } + + private val serverR = + builder + .bindAny() + .withHttpApp(service) + .resource + + private val blazeServer = + ResourceFixture[Server]( + serverR, + (_: TestOptions, _: Server) => IO.unit, + (_: Server) => IO.sleep(100.milliseconds) *> IO.unit, + ) + + private def get(server: Server, path: String): IO[String] = IO.blocking { + AutoCloseableResource.resource( + Source + .fromURL(new URL(s"http://127.0.0.1:${server.address.getPort}$path")) + )(_.getLines().mkString) + } + + private def getStatus(server: Server, path: String): IO[Status] = { + val url = new URL(s"http://127.0.0.1:${server.address.getPort}$path") + for { + conn <- IO.blocking(url.openConnection().asInstanceOf[HttpURLConnection]) + _ = conn.setRequestMethod("GET") + status <- IO + .blocking(conn.getResponseCode()) + .flatMap(code => IO.fromEither(Status.fromInt(code))) + } yield status + } + + private def post(server: Server, path: String, body: String): IO[String] = IO.blocking { + val url = new URL(s"http://127.0.0.1:${server.address.getPort}$path") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + val bytes = body.getBytes(StandardCharsets.UTF_8) + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Length", bytes.size.toString) + conn.setDoOutput(true) + conn.getOutputStream.write(bytes) + + AutoCloseableResource.resource( + Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name) + )(_.getLines().mkString) + } + + private def postChunkedMultipart( + server: Server, + path: String, + boundary: String, + body: String, + ): IO[String] = + IO.blocking { + val url = new URL(s"http://127.0.0.1:${server.address.getPort}$path") + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + val bytes = body.getBytes(StandardCharsets.UTF_8) + conn.setRequestMethod("POST") + conn.setChunkedStreamingMode(-1) + conn.setRequestProperty("Content-Type", s"""multipart/form-data; boundary="$boundary"""") + conn.setDoOutput(true) + conn.getOutputStream.write(bytes) + + AutoCloseableResource.resource( + Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name) + )(_.getLines().mkString) + } + + blazeServer.test("route requests on the service executor".flaky) { server => + get(server, "/thread/routing").map(_.startsWith("blaze-server-suite-compute-")).assert + } + + blazeServer.test("execute the service task on the service executor") { server => + get(server, "/thread/effect").map(_.startsWith("blaze-server-suite-compute-")).assert + } + + blazeServer.test("be able to echo its input") { server => + val input = """{ "Hello": "world" }""" + post(server, "/echo", input).map(_.startsWith(input)).assert + } + + blazeServer.test("return a 503 if the server doesn't respond") { server => + getStatus(server, "/never").assertEquals(Status.ServiceUnavailable) + } + + blazeServer.test("reliably handle multipart requests") { server => + val body = + """|--aa + |server: Server, Content-Disposition: form-data; name="a" + |Content-Length: 1 + | + |a + |--aa--""".stripMargin.replace("\n", "\r\n") + + // This is flaky due to Blaze threading and Java connection pooling. + (1 to 100).toList.traverse { _ => + postChunkedMultipart(server, "/issue2610", "aa", body).assertEquals("a") + } + } + + blazeServer.test("ChannelOptions should default to empty") { _ => + assertEquals(builder.channelOptions, ChannelOptions(Vector.empty)) + } + blazeServer.test("ChannelOptions should set socket send buffer size") { _ => + assertEquals(builder.withSocketSendBufferSize(8192).socketSendBufferSize, Some(8192)) + } + blazeServer.test("ChannelOptions should set socket receive buffer size") { _ => + assertEquals(builder.withSocketReceiveBufferSize(8192).socketReceiveBufferSize, Some(8192)) + } + blazeServer.test("ChannelOptions should set socket keepalive") { _ => + assertEquals(builder.withSocketKeepAlive(true).socketKeepAlive, Some(true)) + } + blazeServer.test("ChannelOptions should set socket reuse address") { _ => + assertEquals(builder.withSocketReuseAddress(true).socketReuseAddress, Some(true)) + } + blazeServer.test("ChannelOptions should set TCP nodelay") { _ => + assertEquals(builder.withTcpNoDelay(true).tcpNoDelay, Some(true)) + } + blazeServer.test("ChannelOptions should unset socket send buffer size") { _ => + assertEquals( + builder + .withSocketSendBufferSize(8192) + .withDefaultSocketSendBufferSize + .socketSendBufferSize, + None, + ) + } + blazeServer.test("ChannelOptions should unset socket receive buffer size") { _ => + assertEquals( + builder + .withSocketReceiveBufferSize(8192) + .withDefaultSocketReceiveBufferSize + .socketReceiveBufferSize, + None, + ) + } + blazeServer.test("ChannelOptions should unset socket keepalive") { _ => + assertEquals(builder.withSocketKeepAlive(true).withDefaultSocketKeepAlive.socketKeepAlive, None) + } + blazeServer.test("ChannelOptions should unset socket reuse address") { _ => + assertEquals( + builder + .withSocketReuseAddress(true) + .withDefaultSocketReuseAddress + .socketReuseAddress, + None, + ) + } + blazeServer.test("ChannelOptions should unset TCP nodelay") { _ => + assertEquals(builder.withTcpNoDelay(true).withDefaultTcpNoDelay.tcpNoDelay, None) + } + blazeServer.test("ChannelOptions should overwrite keys") { _ => + assertEquals( + builder + .withSocketSendBufferSize(8192) + .withSocketSendBufferSize(4096) + .socketSendBufferSize, + Some(4096), + ) + } +} diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/ErrorReporting.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/ErrorReporting.scala new file mode 100644 index 000000000..dfe1a6c4f --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/ErrorReporting.scala @@ -0,0 +1,129 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2013-2020 http4s.org + * + * SPDX-License-Identifier: Apache-2.0 + * + * Based on https://github.com/typelevel/cats-effect/blob/v1.0.0/core/shared/src/test/scala/cats/effect/internals/TestUtils.scala + * Copyright (c) 2017-2018 The Typelevel Cats-effect Project Developers + */ + +package org.http4s +package blaze.server + +import cats.Monad +import cats.syntax.all._ +import org.http4s.headers.Connection +import org.http4s.headers.`Content-Length` + +import java.io.ByteArrayOutputStream +import java.io.OutputStream +import java.io.PrintStream +import scala.util.control.NonFatal + +object NullOutStream extends OutputStream { + override def write(b: Int): Unit = { + // do nothing + } +} + +object ErrorReporting { + + /** Silences System.out and System.err streams for the duration of thunk. + * Restores the original streams before exiting. + */ + def silenceOutputStreams[R](thunk: => R): R = + synchronized { + val originalOut = System.out + val originalErr = System.err + + // Redirect output to dummy stream + val fakeOutStream = new PrintStream(NullOutStream) + val fakeErrStream = new PrintStream(NullOutStream) + System.setOut(fakeOutStream) + System.setErr(fakeErrStream) + try thunk + finally { + // Set back the original streams + System.setOut(originalOut) + System.setErr(originalErr) + } + } + + /** Returns an ErrorHandler that does not log + */ + def silentErrorHandler[F[_], G[_]](implicit + F: Monad[F] + ): Request[F] => PartialFunction[Throwable, F[Response[G]]] = + req => { + case mf: MessageFailure => + mf.toHttpResponse[G](req.httpVersion).pure[F] + case NonFatal(_) => + F.pure( + Response( + Status.InternalServerError, + req.httpVersion, + Headers( + Connection.close, + `Content-Length`.zero, + ), + ) + ) + } + + /** Silences `System.err`, only printing the output in case exceptions are + * thrown by the executed `thunk`. + */ + def silenceSystemErr[A](thunk: => A): A = + synchronized { + // Silencing System.err + val oldErr = System.err + val outStream = new ByteArrayOutputStream() + val fakeErr = new PrintStream(outStream) + System.setErr(fakeErr) + try { + val result = thunk + System.setErr(oldErr) + result + } catch { + case NonFatal(e) => + System.setErr(oldErr) + // In case of errors, print whatever was caught + fakeErr.close() + val out = outStream.toString("utf-8") + if (out.nonEmpty) oldErr.println(out) + throw e + } + } + + /** Catches `System.err` output, for testing purposes. + */ + def catchSystemErr(thunk: => Unit): String = + synchronized { + val oldErr = System.err + val outStream = new ByteArrayOutputStream() + val fakeErr = new PrintStream(outStream) + System.setErr(fakeErr) + try thunk + finally { + System.setErr(oldErr) + fakeErr.close() + } + outStream.toString("utf-8") + } +} diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala new file mode 100644 index 000000000..e05ae4887 --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala @@ -0,0 +1,643 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.data.Kleisli +import cats.effect._ +import cats.effect.kernel.Deferred +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import munit.CatsEffectSuite +import org.http4s.blaze.pipeline.Command.Connected +import org.http4s.blaze.pipeline.Command.Disconnected +import org.http4s.blaze.util.TickWheelExecutor +import org.http4s.blazecore.ResponseParser +import org.http4s.blazecore.SeqTestHead +import org.http4s.dsl.io._ +import org.http4s.headers.Date +import org.http4s.headers.`Content-Length` +import org.http4s.headers.`Transfer-Encoding` +import org.http4s.syntax.all._ +import org.http4s.websocket.WebSocketContext +import org.http4s.{headers => H} +import org.typelevel.ci._ +import org.typelevel.vault._ + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import scala.annotation.nowarn +import scala.concurrent.duration._ + +class Http1ServerStageSpec extends CatsEffectSuite { + + private val fixture = ResourceFixture(Resource.make(IO.delay(new TickWheelExecutor())) { twe => + IO.delay(twe.shutdown()) + }) + + // todo replace with DispatcherIOFixture + val dispatcher = new Fixture[Dispatcher[IO]]("dispatcher") { + + private var d: Dispatcher[IO] = null + private var shutdown: IO[Unit] = null + def apply() = d + override def beforeAll(): Unit = { + val dispatcherAndShutdown = Dispatcher[IO].allocated.unsafeRunSync() + shutdown = dispatcherAndShutdown._2 + d = dispatcherAndShutdown._1 + } + override def afterAll(): Unit = + shutdown.unsafeRunSync() + } + override def munitFixtures = List(dispatcher) + + def makeString(b: ByteBuffer): String = { + val p = b.position() + val a = new Array[Byte](b.remaining()) + b.get(a).position(p) + new String(a) + } + + def parseAndDropDate(buff: ByteBuffer): (Status, Set[Header.Raw], String) = + dropDate(ResponseParser.apply(buff)) + + def dropDate(resp: (Status, Set[Header.Raw], String)): (Status, Set[Header.Raw], String) = { + val hds = resp._2.filter(_.name != Header[Date].name) + (resp._1, hds, resp._3) + } + + def runRequest( + tw: TickWheelExecutor, + req: Seq[String], + httpApp: HttpApp[IO], + maxReqLine: Int = 4 * 1024, + maxHeaders: Int = 16 * 1024, + ): SeqTestHead = { + val head = new SeqTestHead( + req.map(s => ByteBuffer.wrap(s.getBytes(StandardCharsets.ISO_8859_1))) + ) + val httpStage = server.Http1ServerStage[IO]( + httpApp, + () => Vault.empty, + munitExecutionContext, + wsKey = Key.newKey[SyncIO, WebSocketContext[IO]].unsafeRunSync(), + maxReqLine, + maxHeaders, + 10 * 1024, + ErrorReporting.silentErrorHandler, + 30.seconds, + 30.seconds, + tw, + dispatcher(), + None, + ) + + pipeline.LeafBuilder(httpStage).base(head) + head.sendInboundCommand(Connected) + head + } + + val req = "GET /foo HTTP/1.1\r\nheader: value\r\n\r\n" + + private val routes = HttpRoutes + .of[IO] { case _ => + Ok("foo!") + } + .orNotFound + + fixture.test("Http1ServerStage: Invalid Lengths should fail on too long of a request line") { + tickwheel => + runRequest(tickwheel, Seq(req), routes, maxReqLine = 1).result.map { buff => + val str = StandardCharsets.ISO_8859_1.decode(buff.duplicate()).toString + // make sure we don't have signs of chunked encoding. + assert(str.contains("400 Bad Request")) + } + } + + fixture.test("Http1ServerStage: Invalid Lengths should fail on too long of a header") { + tickwheel => + runRequest(tickwheel, Seq(req), routes, maxHeaders = 1).result.map { buff => + val str = StandardCharsets.ISO_8859_1.decode(buff.duplicate()).toString + // make sure we don't have signs of chunked encoding. + assert(str.contains("400 Bad Request")) + } + } + + ServerTestRoutes.testRequestResults.zipWithIndex.foreach { + case ((req, (status, headers, resp)), i) => + if (i == 7 || i == 8) // Awful temporary hack + fixture.test( + s"Http1ServerStage: Common responses should Run request $i Run request: --------\n${req + .split("\r\n\r\n")(0)}\n" + ) { tw => + runRequest(tw, Seq(req), ServerTestRoutes()).result + .map(parseAndDropDate) + .map(assertEquals(_, (status, headers, resp))) + + } + else + fixture.test( + s"Http1ServerStage: Common responses should Run request $i Run request: --------\n${req + .split("\r\n\r\n")(0)}\n" + ) { tw => + runRequest(tw, Seq(req), ServerTestRoutes()).result + .map(parseAndDropDate) + .map(assertEquals(_, (status, headers, resp))) + + } + } + + private val exceptionService = HttpRoutes + .of[IO] { + case GET -> Root / "sync" => sys.error("Synchronous error!") + case GET -> Root / "async" => IO.raiseError(new Exception("Asynchronous error!")) + case GET -> Root / "sync" / "422" => + throw InvalidMessageBodyFailure("lol, I didn't even look") + case GET -> Root / "async" / "422" => + IO.raiseError(InvalidMessageBodyFailure("lol, I didn't even look")) + } + .orNotFound + + private def runError(tw: TickWheelExecutor, path: String) = + runRequest(tw, List(path), exceptionService).result + .map(parseAndDropDate) + .map { case (s, h, r) => + val close = h.exists { h => + h.name == ci"connection" && h.value == "close" + } + (s, close, r) + } + + fixture.test("Http1ServerStage: Errors should Deal with synchronous errors") { tw => + val path = "GET /sync HTTP/1.1\r\nConnection:keep-alive\r\n\r\n" + runError(tw, path).map { case (s, c, _) => + assert(c) + assertEquals(s, InternalServerError) + } + } + + fixture.test("Http1ServerStage: Errors should Call toHttpResponse on synchronous errors") { tw => + val path = "GET /sync/422 HTTP/1.1\r\nConnection:keep-alive\r\n\r\n" + runError(tw, path).map { case (s, c, _) => + assert(!c) + assertEquals(s, UnprocessableEntity) + } + } + + fixture.test("Http1ServerStage: Errors should Deal with asynchronous errors") { tw => + val path = "GET /async HTTP/1.1\r\nConnection:keep-alive\r\n\r\n" + runError(tw, path).map { case (s, c, _) => + assert(c) + assertEquals(s, InternalServerError) + } + } + + fixture.test("Http1ServerStage: Errors should Call toHttpResponse on asynchronous errors") { tw => + val path = "GET /async/422 HTTP/1.1\r\nConnection:keep-alive\r\n\r\n" + runError(tw, path).map { case (s, c, _) => + assert(!c) + assertEquals(s, UnprocessableEntity) + } + } + + fixture.test("Http1ServerStage: Errors should Handle parse error") { tw => + val path = "THIS\u0000IS\u0000NOT\u0000HTTP" + runError(tw, path).map { case (s, c, _) => + assert(c) + assertEquals(s, BadRequest) + } + } + + fixture.test( + "Http1ServerStage: routes should Do not send `Transfer-Encoding: identity` response" + ) { tw => + val routes = HttpRoutes + .of[IO] { case _ => + val headers = Headers(H.`Transfer-Encoding`(TransferCoding.identity)) + IO.pure( + Response[IO](headers = headers) + .withEntity("hello world") + ) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req = "GET /foo HTTP/1.1\r\n\r\n" + + runRequest(tw, Seq(req), routes).result.map { buff => + val str = StandardCharsets.ISO_8859_1.decode(buff.duplicate()).toString + // make sure we don't have signs of chunked encoding. + assert(!str.contains("0\r\n\r\n")) + assert(str.contains("hello world")) + + val (_, hdrs, _) = ResponseParser.apply(buff) + assert(!hdrs.exists(_.name == `Transfer-Encoding`.name)) + } + } + + fixture.test( + "Http1ServerStage: routes should Do not send an entity or entity-headers for a status that doesn't permit it" + ) { tw => + val routes: HttpApp[IO] = HttpRoutes + .of[IO] { case _ => + IO.pure( + Response[IO](status = Status.NotModified) + .putHeaders(`Transfer-Encoding`(TransferCoding.chunked)) + .withEntity("Foo!") + ) + } + .orNotFound + + val req = "GET /foo HTTP/1.1\r\n\r\n" + + runRequest(tw, Seq(req), routes).result.map { buf => + val (status, hs, body) = ResponseParser.parseBuffer(buf) + hs.foreach { h => + assert(`Content-Length`.parse(h.value).isLeft) + } + assertEquals(body, "") + assertEquals(status, Status.NotModified) + } + } + + fixture.test("Http1ServerStage: routes should Add a date header") { tw => + val routes = HttpRoutes + .of[IO] { case req => + IO.pure(Response(body = req.body)) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + + runRequest(tw, Seq(req1), routes).result.map { buff => + // Both responses must succeed + val (_, hdrs, _) = ResponseParser.apply(buff) + assert(hdrs.exists(_.name == Header[Date].name)) + } + } + + fixture.test("Http1ServerStage: routes should Honor an explicitly added date header") { tw => + val dateHeader = Date(HttpDate.Epoch) + val routes = HttpRoutes + .of[IO] { case req => + IO.pure(Response(body = req.body).withHeaders(dateHeader)) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + + runRequest(tw, Seq(req1), routes).result.map { buff => + // Both responses must succeed + val (_, hdrs, _) = ResponseParser.apply(buff) + + val result = hdrs.find(_.name == Header[Date].name).map(_.value) + assertEquals(result, Some(dateHeader.value)) + } + } + + fixture.test( + "Http1ServerStage: routes should Handle routes that echos full request body for non-chunked" + ) { tw => + val routes = HttpRoutes + .of[IO] { case req => + IO.pure(Response(body = req.body)) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val (r11, r12) = req1.splitAt(req1.length - 1) + + runRequest(tw, Seq(r11, r12), routes).result.map { buff => + // Both responses must succeed + assertEquals( + parseAndDropDate(buff), + (Ok, Set(H.`Content-Length`.unsafeFromLong(4).toRaw1), "done"), + ) + } + } + + fixture.test( + "Http1ServerStage: routes should Handle routes that consumes the full request body for non-chunked" + ) { tw => + val routes = HttpRoutes + .of[IO] { case req => + req.as[String].map { s => + Response().withEntity("Result: " + s) + } + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val (r11, r12) = req1.splitAt(req1.length - 1) + + runRequest(tw, Seq(r11, r12), routes).result.map { buff => + // Both responses must succeed + assertEquals( + parseAndDropDate(buff), + ( + Ok, + Set( + H.`Content-Length`.unsafeFromLong(8 + 4).toRaw1, + H.`Content-Type`(MediaType.text.plain, Charset.`UTF-8`).toRaw1, + ), + "Result: done", + ), + ) + } + } + + fixture.test( + "Http1ServerStage: routes should Maintain the connection if the body is ignored but was already read to completion by the Http1Stage" + ) { tw => + val routes = HttpRoutes + .of[IO] { case _ => + IO.pure(Response().withEntity("foo")) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val req2 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 5\r\n\r\ntotal" + + runRequest(tw, Seq(req1, req2), routes).result.map { buff => + val hs = Set( + H.`Content-Type`(MediaType.text.plain, Charset.`UTF-8`).toRaw1, + H.`Content-Length`.unsafeFromLong(3).toRaw1, + ) + // Both responses must succeed + assertEquals(dropDate(ResponseParser.parseBuffer(buff)), (Ok, hs, "foo")) + assertEquals(dropDate(ResponseParser.parseBuffer(buff)), (Ok, hs, "foo")) + } + } + + fixture.test( + "Http1ServerStage: routes should Drop the connection if the body is ignored and was not read to completion by the Http1Stage" + ) { tw => + val routes = HttpRoutes + .of[IO] { case _ => + IO.pure(Response().withEntity("foo")) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val (r11, r12) = req1.splitAt(req1.length - 1) + + val req2 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 5\r\n\r\ntotal" + + runRequest(tw, Seq(r11, r12, req2), routes).result.map { buff => + val hs = Set( + H.`Content-Type`(MediaType.text.plain, Charset.`UTF-8`).toRaw1, + H.`Content-Length`.unsafeFromLong(3).toRaw1, + ) + // Both responses must succeed + assertEquals(dropDate(ResponseParser.parseBuffer(buff)), (Ok, hs, "foo")) + assertEquals(buff.remaining(), 0) + } + } + + fixture.test( + "Http1ServerStage: routes should Handle routes that runs the request body for non-chunked" + ) { tw => + val routes = HttpRoutes + .of[IO] { case req => + req.body.compile.drain *> IO.pure(Response().withEntity("foo")) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val (r11, r12) = req1.splitAt(req1.length - 1) + val req2 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 5\r\n\r\ntotal" + + runRequest(tw, Seq(r11, r12, req2), routes).result.map { buff => + val hs = Set( + H.`Content-Type`(MediaType.text.plain, Charset.`UTF-8`).toRaw1, + H.`Content-Length`.unsafeFromLong(3).toRaw1, + ) + // Both responses must succeed + assertEquals(dropDate(ResponseParser.parseBuffer(buff)), (Ok, hs, "foo")) + assertEquals(dropDate(ResponseParser.parseBuffer(buff)), (Ok, hs, "foo")) + } + } + + // Think of this as drunk HTTP pipelining + fixture.test("Http1ServerStage: routes should Not die when two requests come in back to back") { + tw => + val routes = HttpRoutes + .of[IO] { case req => + IO.pure(Response(body = req.body)) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val req2 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 5\r\n\r\ntotal" + + runRequest(tw, Seq(req1 + req2), routes).result.map { buff => + // Both responses must succeed + assertEquals( + dropDate(ResponseParser.parseBuffer(buff)), + (Ok, Set(H.`Content-Length`.unsafeFromLong(4).toRaw1), "done"), + ) + assertEquals( + dropDate(ResponseParser.parseBuffer(buff)), + (Ok, Set(H.`Content-Length`.unsafeFromLong(5).toRaw1), "total"), + ) + } + } + + fixture.test( + "Http1ServerStage: routes should Handle using the request body as the response body" + ) { tw => + val routes = HttpRoutes + .of[IO] { case req => + IO.pure(Response(body = req.body)) + } + .orNotFound + + // The first request will get split into two chunks, leaving the last byte off + val req1 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val req2 = "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 5\r\n\r\ntotal" + + runRequest(tw, Seq(req1, req2), routes).result.map { buff => + // Both responses must succeed + assertEquals( + dropDate(ResponseParser.parseBuffer(buff)), + (Ok, Set(H.`Content-Length`.unsafeFromLong(4).toRaw1), "done"), + ) + assertEquals( + dropDate(ResponseParser.parseBuffer(buff)), + (Ok, Set(H.`Content-Length`.unsafeFromLong(5).toRaw1), "total"), + ) + } + } + + private def req(path: String) = + s"POST /$path HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + + "3\r\n" + + "foo\r\n" + + "0\r\n" + + "Foo:Bar\r\n\r\n" + + private val routes2 = HttpRoutes + .of[IO] { + case req if req.pathInfo === path"/foo" => + for { + _ <- req.body.compile.drain + hs <- req.trailerHeaders + resp <- Ok(hs.headers.mkString) + } yield resp + + case req if req.pathInfo === path"/bar" => + for { + // Don't run the body + hs <- req.trailerHeaders + resp <- Ok(hs.headers.mkString) + } yield resp + } + .orNotFound + + fixture.test("Http1ServerStage: routes should Handle trailing headers") { tw => + (runRequest(tw, Seq(req("foo")), routes2).result).map { buff => + val results = dropDate(ResponseParser.parseBuffer(buff)) + assertEquals(results._1, Ok) + assertEquals(results._3, "Foo: Bar") + } + } + + fixture.test( + "Http1ServerStage: routes should Fail if you use the trailers before they have resolved" + ) { tw => + runRequest(tw, Seq(req("bar")), routes2).result.map { buff => + val results = dropDate(ResponseParser.parseBuffer(buff)) + assertEquals(results._1, InternalServerError) + } + } + + fixture.test("Http1ServerStage: routes should cancels on stage shutdown".flaky) { tw => + Deferred[IO, Unit] + .flatMap { canceled => + Deferred[IO, Unit].flatMap { gate => + val req = + "POST /sync HTTP/1.1\r\nConnection:keep-alive\r\nContent-Length: 4\r\n\r\ndone" + val app: HttpApp[IO] = HttpApp { _ => + gate.complete(()) >> canceled.complete(()) >> IO.never[Response[IO]] + } + for { + head <- IO(runRequest(tw, List(req), app)) + _ <- gate.get + _ <- IO(head.closePipeline(None)) + _ <- canceled.get + } yield () + } + } + } + + fixture.test("Http1ServerStage: routes should Disconnect if we read an EOF") { tw => + val head = runRequest(tw, Seq.empty, Kleisli.liftF(Ok(""))) + head.result.map { _ => + assertEquals(head.closeCauses, Vector(None)) + } + } + + fixture.test("Prevent response splitting attacks on status reason phrase") { tw => + val rawReq = "GET /?reason=%0D%0AEvil:true%0D%0A HTTP/1.0\r\n\r\n" + @nowarn("cat=deprecation") + val head = runRequest( + tw, + List(rawReq), + HttpApp { req => + Response[IO](Status.NoContent.withReason(req.params("reason"))).pure[IO] + }, + ) + head.result.map { buff => + val (_, headers, _) = ResponseParser.parseBuffer(buff) + assertEquals(headers.find(_.name === ci"Evil"), None) + } + } + + fixture.test("Prevent response splitting attacks on field name") { tw => + val rawReq = "GET /?fieldName=Fine:%0D%0AEvil:true%0D%0A HTTP/1.0\r\n\r\n" + val head = runRequest( + tw, + List(rawReq), + HttpApp { req => + Response[IO](Status.NoContent).putHeaders(req.params("fieldName") -> "oops").pure[IO] + }, + ) + head.result.map { buff => + val (_, headers, _) = ResponseParser.parseBuffer(buff) + assertEquals(headers.find(_.name === ci"Evil"), None) + } + } + + fixture.test("Prevent response splitting attacks on field value") { tw => + val rawReq = "GET /?fieldValue=%0D%0AEvil:true%0D%0A HTTP/1.0\r\n\r\n" + val head = runRequest( + tw, + List(rawReq), + HttpApp { req => + Response[IO](Status.NoContent) + .putHeaders("X-Oops" -> req.params("fieldValue")) + .pure[IO] + }, + ) + head.result.map { buff => + val (_, headers, _) = ResponseParser.parseBuffer(buff) + assertEquals(headers.find(_.name === ci"Evil"), None) + } + + fixture.test("Http1ServerStage: don't deadlock TickWheelExecutor with uncancelable request") { + tw => + val reqUncancelable = List("GET /uncancelable HTTP/1.0\r\n\r\n") + val reqCancelable = List("GET /cancelable HTTP/1.0\r\n\r\n") + + (for { + uncancelableStarted <- Deferred[IO, Unit] + uncancelableCanceled <- Deferred[IO, Unit] + cancelableStarted <- Deferred[IO, Unit] + cancelableCanceled <- Deferred[IO, Unit] + app = HttpApp[IO] { + case req if req.pathInfo === path"/uncancelable" => + uncancelableStarted.complete(()) *> + IO.uncancelable { poll => + poll(uncancelableCanceled.complete(())) *> + cancelableCanceled.get + }.as(Response[IO]()) + case _ => + cancelableStarted.complete(()) *> IO.never.guarantee( + cancelableCanceled.complete(()).void + ) + } + head <- IO(runRequest(tw, reqUncancelable, app)) + _ <- uncancelableStarted.get + _ <- uncancelableCanceled.get + _ <- IO(head.sendInboundCommand(Disconnected)) + head2 <- IO(runRequest(tw, reqCancelable, app)) + _ <- cancelableStarted.get + _ <- IO(head2.sendInboundCommand(Disconnected)) + _ <- cancelableCanceled.get + } yield ()).assert + } + } +} diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/ServerTestRoutes.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/ServerTestRoutes.scala new file mode 100644 index 000000000..832001d32 --- /dev/null +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/ServerTestRoutes.scala @@ -0,0 +1,155 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s +package blaze +package server + +import cats.data.Kleisli +import cats.effect._ +import fs2.Stream._ +import org.http4s.Charset._ +import org.http4s.dsl.Http4sDsl +import org.http4s.headers._ +import org.http4s.implicits._ +import org.typelevel.ci._ + +object ServerTestRoutes extends Http4sDsl[IO] { + private val textPlain = `Content-Type`(MediaType.text.plain, `UTF-8`).toRaw1 + private val connClose = Connection.close.toRaw1 + private val connKeep = Connection(ci"keep-alive").toRaw1 + private val chunked = `Transfer-Encoding`(TransferCoding.chunked).toRaw1 + + def length(l: Long): Header.Raw = `Content-Length`.unsafeFromLong(l).toRaw1 + def testRequestResults: Seq[(String, (Status, Set[Header.Raw], String))] = + Seq( + ("GET /get HTTP/1.0\r\n\r\n", (Status.Ok, Set(length(3), textPlain), "get")), + // /////////////////////////////// + ("GET /get HTTP/1.1\r\n\r\n", (Status.Ok, Set(length(3), textPlain), "get")), + // /////////////////////////////// + ( + "GET /get HTTP/1.0\r\nConnection:keep-alive\r\n\r\n", + (Status.Ok, Set(length(3), textPlain, connKeep), "get"), + ), + // /////////////////////////////// + ( + "GET /get HTTP/1.1\r\nConnection:keep-alive\r\n\r\n", + (Status.Ok, Set(length(3), textPlain), "get"), + ), + // /////////////////////////////// + ( + "GET /get HTTP/1.1\r\nConnection:close\r\n\r\n", + (Status.Ok, Set(length(3), textPlain, connClose), "get"), + ), + // /////////////////////////////// + ( + "GET /get HTTP/1.0\r\nConnection:close\r\n\r\n", + (Status.Ok, Set(length(3), textPlain, connClose), "get"), + ), + // /////////////////////////////// + ( + "GET /get HTTP/1.1\r\nConnection:close\r\n\r\n", + (Status.Ok, Set(length(3), textPlain, connClose), "get"), + ), + ("GET /chunked HTTP/1.1\r\n\r\n", (Status.Ok, Set(textPlain, chunked), "chunk")), + // /////////////////////////////// + ( + "GET /chunked HTTP/1.1\r\nConnection:close\r\n\r\n", + (Status.Ok, Set(textPlain, chunked, connClose), "chunk"), + ), + // /////////////////////////////// Content-Length and Transfer-Encoding free responses for HTTP/1.0 + ("GET /chunked HTTP/1.0\r\n\r\n", (Status.Ok, Set(textPlain), "chunk")), + // /////////////////////////////// + ( + "GET /chunked HTTP/1.0\r\nConnection:Close\r\n\r\n", + (Status.Ok, Set(textPlain, connClose), "chunk"), + ), + // ////////////////////////////// Requests with a body ////////////////////////////////////// + ( + "POST /post HTTP/1.1\r\nContent-Length:3\r\n\r\nfoo", + (Status.Ok, Set(textPlain, length(4)), "post"), + ), + // /////////////////////////////// + ( + "POST /post HTTP/1.1\r\nConnection:close\r\nContent-Length:3\r\n\r\nfoo", + (Status.Ok, Set(textPlain, length(4), connClose), "post"), + ), + // /////////////////////////////// + ( + "POST /post HTTP/1.0\r\nConnection:close\r\nContent-Length:3\r\n\r\nfoo", + (Status.Ok, Set(textPlain, length(4), connClose), "post"), + ), + // /////////////////////////////// + ( + "POST /post HTTP/1.0\r\nContent-Length:3\r\n\r\nfoo", + (Status.Ok, Set(textPlain, length(4)), "post"), + ), + // //////////////////////////////////////////////////////////////////// + ( + "POST /post HTTP/1.1\r\nTransfer-Encoding:chunked\r\n\r\n3\r\nfoo\r\n0\r\n\r\n", + (Status.Ok, Set(textPlain, length(4)), "post"), + ), + // /////////////////////////////// + ( + "POST /post HTTP/1.1\r\nConnection:close\r\nTransfer-Encoding:chunked\r\n\r\n3\r\nfoo\r\n0\r\n\r\n", + (Status.Ok, Set(textPlain, length(4), connClose), "post"), + ), + ( + "POST /post HTTP/1.1\r\nTransfer-Encoding:chunked\r\n\r\n3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n", + (Status.Ok, Set(textPlain, length(4)), "post"), + ), + // /////////////////////////////// + ( + "POST /post HTTP/1.1\r\nConnection:Close\r\nTransfer-Encoding:chunked\r\n\r\n3\r\nfoo\r\n0\r\n\r\n", + (Status.Ok, Set(textPlain, length(4), connClose), "post"), + ), + // /////////////////////////////// Check corner cases ////////////////// + ( + "GET /twocodings HTTP/1.0\r\nConnection:Close\r\n\r\n", + (Status.Ok, Set(textPlain, length(3), connClose), "Foo"), + ), + // /////////////// Work with examples that don't have a body ////////////////////// + ("GET /notmodified HTTP/1.1\r\n\r\n", (Status.NotModified, Set(), "")), + ( + "GET /notmodified HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n", + (Status.NotModified, Set(connKeep), ""), + ), + ) + + def apply(): Kleisli[IO, Request[IO], Response[IO]] = + HttpRoutes + .of[IO] { + case req if req.method == Method.GET && req.pathInfo == path"/get" => + Ok("get") + + case req if req.method == Method.GET && req.pathInfo == path"/chunked" => + Ok(eval(IO.cede *> IO("chu")) ++ eval(IO.cede *> IO("nk"))) + + case req if req.method == Method.POST && req.pathInfo == path"/post" => + Ok("post") + + case req if req.method == Method.GET && req.pathInfo == path"/twocodings" => + Ok("Foo", `Transfer-Encoding`(TransferCoding.chunked)) + + case req if req.method == Method.POST && req.pathInfo == path"/echo" => + Ok(emit("post") ++ req.bodyText) + + // Kind of cheating, as the real NotModified response should have a Date header representing the current? time? + case req if req.method == Method.GET && req.pathInfo == path"/notmodified" => + NotModified() + } + .orNotFound +} diff --git a/build.sbt b/build.sbt index 36b4d1325..44242a028 100644 --- a/build.sbt +++ b/build.sbt @@ -3,20 +3,40 @@ import Dependencies._ val Scala212 = "2.12.15" val Scala213 = "2.13.8" -val Scala3 = "3.0.2" +val Scala3 = "3.1.2" +val http4sVersion = "0.23.11-473-e7e64cb-SNAPSHOT" +val munitCatsEffectVersion = "1.0.7" + +ThisBuild / resolvers += + "s01 snapshots".at("https://s01.oss.sonatype.org/content/repositories/snapshots/") ThisBuild / crossScalaVersions := Seq(Scala3, Scala212, Scala213) ThisBuild / scalaVersion := crossScalaVersions.value.filter(_.startsWith("2.")).last -ThisBuild / tlBaseVersion := "0.15" -ThisBuild / tlVersionIntroduced := Map( - "2.13" -> "0.14.5", - "3" -> "0.15.0" -) +ThisBuild / tlBaseVersion := "0.23" ThisBuild / tlFatalWarningsInCi := !tlIsScala3.value // See SSLStage // 11 and 17 blocked by https://github.com/http4s/blaze/issues/376 ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("8")) +ThisBuild / developers ++= List( + Developer( + "bryce-anderson", + "Bryce L. Anderson", + "bryce.anderson22@gamil.com", + url("https://github.com/bryce-anderson")), + Developer( + "rossabaker", + "Ross A. Baker", + "ross@rossabaker.com", + url("https://github.com/rossabaker")), + Developer( + "ChristopherDavenport", + "Christopher Davenport", + "chris@christopherdavenport.tech", + url("https://github.com/ChristopherDavenport")) +) +ThisBuild / startYear := Some(2014) + lazy val commonSettings = Seq( description := "NIO Framework for Scala", Test / scalacOptions ~= (_.filterNot(Set("-Ywarn-dead-code", "-Wdead-code"))), // because mockito @@ -31,24 +51,7 @@ lazy val commonSettings = Seq( } }, run / fork := true, - developers ++= List( - Developer( - "bryce-anderson", - "Bryce L. Anderson", - "bryce.anderson22@gamil.com", - url("https://github.com/bryce-anderson")), - Developer( - "rossabaker", - "Ross A. Baker", - "ross@rossabaker.com", - url("https://github.com/rossabaker")), - Developer( - "ChristopherDavenport", - "Christopher Davenport", - "chris@christopherdavenport.tech", - url("https://github.com/ChristopherDavenport")) - ), - startYear := Some(2014) + scalafmtConfig := file(".scalafmt.blaze.conf") ) // currently only publishing tags @@ -66,7 +69,7 @@ lazy val blaze = project .enablePlugins(Http4sOrgPlugin) .enablePlugins(NoPublishPlugin) .settings(commonSettings) - .aggregate(core, http, examples) + .aggregate(core, http, blazeCore, blazeServer, blazeClient, examples) lazy val testkit = Project("blaze-testkit", file("testkit")) .enablePlugins(NoPublishPlugin) @@ -117,11 +120,223 @@ lazy val http = Project("blaze-http", file("http")) ) .dependsOn(testkit % Test, core % "test->test;compile->compile") +lazy val blazeCore = project + .in(file("blaze-core")) + .settings( + name := "http4s-blaze-core", + description := "Base library for binding blaze to http4s clients and servers", + startYear := Some(2014), + tlMimaPreviousVersions ++= (0 to 11).map(y => s"0.23.$y").toSet, + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-core" % http4sVersion, + "org.typelevel" %% "munit-cats-effect-3" % munitCatsEffectVersion % Test + ), + mimaBinaryIssueFilters := { + if (tlIsScala3.value) + Seq( + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.BodylessWriter.this"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.BodylessWriter.ec"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.EntityBodyWriter.ec"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.CachingChunkWriter.ec"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.http4s.blazecore.util.CachingStaticWriter.this" + ), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.http4s.blazecore.util.CachingStaticWriter.ec" + ), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.http4s.blazecore.util.FlushingChunkWriter.ec" + ), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.Http2Writer.this"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.Http2Writer.ec"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.IdentityWriter.this"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blazecore.util.IdentityWriter.ec") + ) + else Seq.empty + } + ) + .dependsOn(http) + +lazy val blazeServer = project + .in(file("blaze-server")) + .settings( + name := "http4s-blaze-server", + description := "blaze implementation for http4s servers", + startYear := Some(2014), + tlMimaPreviousVersions ++= (0 to 11).map(y => s"0.23.$y").toSet, + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-server" % http4sVersion, + "org.http4s" %% "http4s-dsl" % http4sVersion % Test + ), + mimaBinaryIssueFilters := Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.http4s.blaze.server.BlazeServerBuilder.this" + ), // private + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.http4s.blaze.server.WebSocketDecoder.this" + ), // private + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.http4s.blaze.server.BlazeServerBuilder.this" + ), // private + ProblemFilters.exclude[MissingClassProblem]( + "org.http4s.blaze.server.BlazeServerBuilder$ExecutionContextConfig" + ), // private + ProblemFilters.exclude[MissingClassProblem]( + "org.http4s.blaze.server.BlazeServerBuilder$ExecutionContextConfig$" + ), // private + ProblemFilters.exclude[MissingClassProblem]( + "org.http4s.blaze.server.BlazeServerBuilder$ExecutionContextConfig$DefaultContext$" + ), // private + ProblemFilters.exclude[MissingClassProblem]( + "org.http4s.blaze.server.BlazeServerBuilder$ExecutionContextConfig$ExplicitContext" + ), // private + ProblemFilters.exclude[MissingClassProblem]( + "org.http4s.blaze.server.BlazeServerBuilder$ExecutionContextConfig$ExplicitContext$" + ), // private + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.BlazeServerBuilder.this"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.WebSocketDecoder.this") + ) ++ { + if (tlIsScala3.value) + Seq( + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.Http1ServerStage.apply"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.Http1ServerStage.apply"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.ProtocolSelector.apply"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.server.ProtocolSelector.apply"), + ProblemFilters.exclude[ReversedMissingMethodProblem]( + "org.http4s.blaze.server.WebSocketSupport.maxBufferSize" + ), + ProblemFilters.exclude[ReversedMissingMethodProblem]( + "org.http4s.blaze.server.WebSocketSupport.webSocketKey" + ) + ) + else Seq.empty, + } + ) + .dependsOn(blazeCore % "compile;test->test") + +lazy val blazeClient = project + .in(file("blaze-client")) + .settings( + name := "http4s-blaze-client", + description := "blaze implementation for http4s clients", + startYear := Some(2014), + tlMimaPreviousVersions ++= (0 to 11).map(y => s"0.23.$y").toSet, + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-client" % http4sVersion, + "org.http4s" %% "http4s-client-testkit" % http4sVersion % Test + ), + mimaBinaryIssueFilters ++= Seq( + // private constructor + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.BlazeClientBuilder.this"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.Http1Support.this"), + // These are all private to blaze-client and fallout from from + // the deprecation of org.http4s.client.Connection + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.BasicManager.invalidate"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.BasicManager.release"), + ProblemFilters.exclude[MissingTypesProblem]("org.http4s.blaze.client.BlazeConnection"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.ConnectionManager.release"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.http4s.blaze.client.ConnectionManager.invalidate" + ), + ProblemFilters + .exclude[ReversedMissingMethodProblem]("org.http4s.blaze.client.ConnectionManager.release"), + ProblemFilters.exclude[ReversedMissingMethodProblem]( + "org.http4s.blaze.client.ConnectionManager.invalidate" + ), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection.connection" + ), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection.copy" + ), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection.copy$default$1" + ), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection.this" + ), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection.apply" + ), + ProblemFilters.exclude[MissingTypesProblem]("org.http4s.blaze.client.Http1Connection"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.PoolManager.release"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.PoolManager.invalidate"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.BasicManager.this"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.ConnectionManager.pool"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.ConnectionManager.basic"), + ProblemFilters + .exclude[IncompatibleMethTypeProblem]("org.http4s.blaze.client.PoolManager.this"), + // inside private trait/clas/object + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.client.BlazeConnection.runRequest"), + ProblemFilters.exclude[ReversedMissingMethodProblem]( + "org.http4s.blaze.client.BlazeConnection.runRequest" + ), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.client.Http1Connection.runRequest"), + ProblemFilters + .exclude[DirectMissingMethodProblem]("org.http4s.blaze.client.Http1Connection.resetWrite"), + ProblemFilters.exclude[MissingClassProblem]("org.http4s.blaze.client.Http1Connection$Idle"), + ProblemFilters.exclude[MissingClassProblem]("org.http4s.blaze.client.Http1Connection$Idle$"), + ProblemFilters.exclude[MissingClassProblem]("org.http4s.blaze.client.Http1Connection$Read$"), + ProblemFilters + .exclude[MissingClassProblem]("org.http4s.blaze.client.Http1Connection$ReadWrite$"), + ProblemFilters.exclude[MissingClassProblem]("org.http4s.blaze.client.Http1Connection$Write$"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.http4s.blaze.client.Http1Connection.isRecyclable" + ), + ProblemFilters + .exclude[IncompatibleResultTypeProblem]("org.http4s.blaze.client.Connection.isRecyclable"), + ProblemFilters + .exclude[ReversedMissingMethodProblem]("org.http4s.blaze.client.Connection.isRecyclable") + ) ++ { + if (tlIsScala3.value) + Seq( + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.http4s.blaze.client.ConnectionManager#NextConnection._1" + ) + ) + else Seq.empty + } + ) + .dependsOn(blazeCore % "compile;test->test") + lazy val examples = Project("blaze-examples", file("examples")) .enablePlugins(NoPublishPlugin) .settings(commonSettings) .settings(Revolver.settings) - .dependsOn(http) + .settings( + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-dsl" % http4sVersion, + "org.http4s" %% "http4s-circe" % http4sVersion, + "io.circe" %% "circe-generic" % "0.14.2" + ) + ) + .dependsOn(blazeServer, blazeClient) /* Helper Functions */ diff --git a/examples/src/main/resources/beerbottle.png b/examples/src/main/resources/beerbottle.png new file mode 100644 index 000000000..607bcfe25 Binary files /dev/null and b/examples/src/main/resources/beerbottle.png differ diff --git a/examples/src/main/resources/logback.xml b/examples/src/main/resources/logback.xml new file mode 100644 index 000000000..6b246ee13 --- /dev/null +++ b/examples/src/main/resources/logback.xml @@ -0,0 +1,14 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/examples/src/main/resources/server.jks b/examples/src/main/resources/server.jks new file mode 100644 index 000000000..f12d00673 Binary files /dev/null and b/examples/src/main/resources/server.jks differ diff --git a/examples/src/main/scala/com/example/http4s/ExampleService.scala b/examples/src/main/scala/com/example/http4s/ExampleService.scala new file mode 100644 index 000000000..b2ab233a8 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/ExampleService.scala @@ -0,0 +1,212 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s + +import cats.effect._ +import cats.syntax.all._ +import fs2.Stream +import io.circe.Json +import org.http4s._ +import org.http4s.circe._ +import org.http4s.dsl.Http4sDsl +import org.http4s.headers._ +import org.http4s.multipart.Multipart +import org.http4s.server._ +import org.http4s.server.middleware.authentication.BasicAuth +import org.http4s.server.middleware.authentication.BasicAuth.BasicAuthenticator +import org.http4s.syntax.all._ + +import scala.concurrent.duration._ + +class ExampleService[F[_]](implicit F: Async[F]) extends Http4sDsl[F] { + // A Router can mount multiple services to prefixes. The request is passed to the + // service with the longest matching prefix. + def routes: HttpRoutes[F] = + Router[F]( + "" -> rootRoutes, + "/auth" -> authRoutes + ) + + def rootRoutes: HttpRoutes[F] = + HttpRoutes.of[F] { + case GET -> Root => + // disabled until twirl supports dotty + // Supports Play Framework template -- see src/main/twirl. + // Ok(html.index()) + Ok("Hello World") + + case _ -> Root => + // The default route result is NotFound. Sometimes MethodNotAllowed is more appropriate. + MethodNotAllowed(Allow(GET)) + + case GET -> Root / "ping" => + // EntityEncoder allows for easy conversion of types to a response body + Ok("pong") + + case GET -> Root / "streaming" => + // It's also easy to stream responses to clients + Ok(dataStream(100)) + + case req @ GET -> Root / "ip" => + // It's possible to define an EntityEncoder anywhere so you're not limited to built in types + val json = Json.obj("origin" -> Json.fromString(req.remoteAddr.fold("unknown")(_.toString))) + Ok(json) + + case GET -> Root / "redirect" => + // Not every response must be Ok using a EntityEncoder: some have meaning only for specific types + TemporaryRedirect(Location(uri"/http4s/")) + + case GET -> Root / "content-change" => + // EntityEncoder typically deals with appropriate headers, but they can be overridden + Ok("

This will have an html content type!

", `Content-Type`(MediaType.text.html)) + + case req @ GET -> "static" /: path => + // captures everything after "/static" into `path` + // Try http://localhost:8080/http4s/static/nasa_blackhole_image.jpg + // See also org.http4s.server.staticcontent to create a mountable service for static content + StaticFile.fromResource(path.toString, Some(req)).getOrElseF(NotFound()) + + // ///////////////////////////////////////////////////////////// + // ////////////// Dealing with the message body //////////////// + case req @ POST -> Root / "echo" => + // The body can be used in the response + Ok(req.body).map(_.putHeaders(`Content-Type`(MediaType.text.plain))) + + case GET -> Root / "echo" => + // disabled until twirl supports dotty + // Ok(html.submissionForm("echo data")) + Ok("Hello World") + + case req @ POST -> Root / "echo2" => + // Even more useful, the body can be transformed in the response + Ok(req.body.drop(6), `Content-Type`(MediaType.text.plain)) + + case GET -> Root / "echo2" => + // disabled until twirl supports dotty + // Ok(html.submissionForm("echo data")) + Ok("Hello World") + + case req @ POST -> Root / "sum" => + // EntityDecoders allow turning the body into something useful + req + .decode[UrlForm] { data => + data.values.get("sum").flatMap(_.uncons) match { + case Some((s, _)) => + val sum = s.split(' ').filter(_.nonEmpty).map(_.trim.toInt).sum + Ok(sum.toString) + + case None => BadRequest(s"Invalid data: " + data) + } + } + .handleErrorWith { // We can handle errors using effect methods + case e: NumberFormatException => BadRequest("Not an int: " + e.getMessage) + } + + case GET -> Root / "sum" => + // disabled until twirl supports dotty + // Ok(html.submissionForm("sum")) + Ok("Hello World") + + // ///////////////////////////////////////////////////////////// + // //////////////////// Blaze examples ///////////////////////// + + // You can use the same service for GET and HEAD. For HEAD request, + // only the Content-Length is sent (if static content) + case GET -> Root / "helloworld" => + helloWorldService + case HEAD -> Root / "helloworld" => + helloWorldService + + // HEAD responses with Content-Length, but empty content + case HEAD -> Root / "head" => + Ok("", `Content-Length`.unsafeFromLong(1024)) + + // Response with invalid Content-Length header generates + // an error (underflow causes the connection to be closed) + case GET -> Root / "underflow" => + Ok("foo", `Content-Length`.unsafeFromLong(4)) + + // Response with invalid Content-Length header generates + // an error (overflow causes the extra bytes to be ignored) + case GET -> Root / "overflow" => + Ok("foo", `Content-Length`.unsafeFromLong(2)) + + // ///////////////////////////////////////////////////////////// + // ////////////// Form encoding example //////////////////////// + case GET -> Root / "form-encoded" => + // disabled until twirl supports dotty + // Ok(html.formEncoded()) + Ok("Hello World") + + case req @ POST -> Root / "form-encoded" => + // EntityDecoders return an F[A] which is easy to sequence + req.decode[UrlForm] { m => + val s = m.values.mkString("\n") + Ok(s"Form Encoded Data\n$s") + } + + // ///////////////////////////////////////////////////////////// + // ////////////////////// Multi Part ////////////////////////// + case GET -> Root / "form" => + // disabled until twirl supports dotty + // Ok(html.form()) + Ok("Hello World") + + case req @ POST -> Root / "multipart" => + req.decode[Multipart[F]] { m => + Ok(s"""Multipart Data\nParts:${m.parts.length}\n${m.parts.map(_.name).mkString("\n")}""") + } + } + + def helloWorldService: F[Response[F]] = Ok("Hello World!") + + // This is a mock data source, but could be a Process representing results from a database + def dataStream(n: Int)(implicit clock: Clock[F]): Stream[F, String] = { + val interval = 100.millis + val stream = Stream + .awakeEvery[F](interval) + .evalMap(_ => clock.realTime) + .map(time => s"Current system time: $time ms\n") + .take(n.toLong) + + Stream.emit(s"Starting $interval stream intervals, taking $n results\n\n") ++ stream + } + + // Services can be protected using HTTP authentication. + val realm = "testrealm" + + val authStore: BasicAuthenticator[F, String] = (creds: BasicCredentials) => + if (creds.username == "username" && creds.password == "password") F.pure(Some(creds.username)) + else F.pure(None) + + // An AuthedRoutes[A, F] is a Service[F, (A, Request[F]), Response[F]] for some + // user type A. `BasicAuth` is an auth middleware, which binds an + // AuthedRoutes to an authentication store. + val basicAuth: AuthMiddleware[F, String] = BasicAuth(realm, authStore) + + def authRoutes: HttpRoutes[F] = + basicAuth(AuthedRoutes.of[String, F] { + // AuthedRoutes look like HttpRoutes, but the user is extracted with `as`. + case GET -> Root / "protected" as user => + Ok(s"This page is protected using HTTP authentication; logged in as $user") + }) +} + +object ExampleService { + def apply[F[_]: Async]: ExampleService[F] = + new ExampleService[F] +} diff --git a/examples/src/main/scala/com/example/http4s/HeaderExamples.scala b/examples/src/main/scala/com/example/http4s/HeaderExamples.scala new file mode 100644 index 000000000..84b3b46a4 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/HeaderExamples.scala @@ -0,0 +1,106 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.examples.http4s + +import cats.Semigroup +import cats.data.NonEmptyList +import cats.syntax.all._ +import org.http4s._ +import org.typelevel.ci._ + +// TODO migrate to a proper mdoc. This is to keep it compiling. + +object HeaderExamples { + // /// test for construction + final case class Foo(v: String) + object Foo { + implicit def headerFoo: Header[Foo, Header.Single] = new Header[Foo, Header.Single] { + def name = ci"foo" + def value(f: Foo) = f.v + def parse(s: String) = Foo(s).asRight + } + + } + def baz: Header.Raw = Header.Raw(ci"baz", "bbb") + + val myHeaders: Headers = Headers( + Foo("hello"), + "my" -> "header", + baz + ) + // //// test for selection + final case class Bar(v: NonEmptyList[String]) + object Bar { + implicit val headerBar: Header[Bar, Header.Recurring] with Semigroup[Bar] = + new Header[Bar, Header.Recurring] with Semigroup[Bar] { + def name = ci"Bar" + def value(b: Bar) = b.v.toList.mkString(",") + def parse(s: String) = Bar(NonEmptyList.one(s)).asRight + def combine(a: Bar, b: Bar) = Bar(a.v |+| b.v) + } + } + + final case class SetCookie(name: String, value: String) + object SetCookie { + implicit val headerCookie: Header[SetCookie, Header.Recurring] = + new Header[SetCookie, Header.Recurring] { + def name = ci"Set-Cookie" + def value(c: SetCookie) = s"${c.name}:${c.value}" + def parse(s: String) = + s.split(':').toList match { + case List(name, value) => SetCookie(name, value).asRight + case _ => Left(ParseFailure("Malformed cookie", "")) + } + } + } + + val hs: Headers = Headers( + Bar(NonEmptyList.one("one")), + Foo("two"), + SetCookie("cookie1", "a cookie"), + Bar(NonEmptyList.one("three")), + SetCookie("cookie2", "another cookie") + ) + + val a: Option[Foo] = hs.get[Foo] + val b: Option[Bar] = hs.get[Bar] + val c: Option[NonEmptyList[SetCookie]] = hs.get[SetCookie] + + // scala> Examples.a + // val res0: Option[Foo] = Some(Foo(two)) + + // scala> Examples.b + // val res1: Option[Bar] = Some(Bar(NonEmptyList(one, three))) + + // scala> Examples.c + // val res2: Option[NonEmptyList[SetCookie]] = Some(NonEmptyList(SetCookie(cookie1,a cookie), SetCookie(cookie2,another cookie))) + + val hs2: Headers = Headers( + Bar(NonEmptyList.one("one")), + Foo("two"), + SetCookie("cookie1", "a cookie"), + Bar(NonEmptyList.one("three")), + SetCookie("cookie2", "another cookie"), + "a" -> "b", + Option("a" -> "c"), + List("a" -> "c"), + List(SetCookie("cookie3", "cookie three")) + // , + // Option(List("a" -> "c")) // correctly fails to compile + ) + +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/BlazeExample.scala b/examples/src/main/scala/com/example/http4s/blaze/BlazeExample.scala new file mode 100644 index 000000000..c4b000195 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/BlazeExample.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze + +import cats.effect._ +import com.example.http4s.ExampleService +import org.http4s.HttpApp +import org.http4s.blaze.server.BlazeServerBuilder +import org.http4s.server.Router +import org.http4s.server.Server + +object BlazeExample extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + BlazeExampleApp.resource[IO].use(_ => IO.never).as(ExitCode.Success) +} + +object BlazeExampleApp { + def httpApp[F[_]: Async]: HttpApp[F] = + Router( + "/http4s" -> ExampleService[F].routes + ).orNotFound + + def resource[F[_]: Async]: Resource[F, Server] = { + val app = httpApp[F] + BlazeServerBuilder[F] + .bindHttp(8080) + .withHttpApp(app) + .resource + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/BlazeHttp2Example.scala b/examples/src/main/scala/com/example/http4s/blaze/BlazeHttp2Example.scala new file mode 100644 index 000000000..ede23cad0 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/BlazeHttp2Example.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s +package blaze + +import cats.effect._ + +object BlazeHttp2Example extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + BlazeSslExampleApp.builder[IO].flatMap(_.enableHttp2(true).serve.compile.lastOrError) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExample.scala b/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExample.scala new file mode 100644 index 000000000..9ec72b6ac --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExample.scala @@ -0,0 +1,48 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s +package blaze + +import cats.effect._ +import cats.syntax.all._ +import org.http4s.blaze.server.BlazeServerBuilder +import org.http4s.server.Server + +import javax.net.ssl.SSLContext + +object BlazeSslExample extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + BlazeSslExampleApp.resource[IO].use(_ => IO.never).as(ExitCode.Success) +} + +object BlazeSslExampleApp { + def context[F[_]: Sync]: F[SSLContext] = + ssl.loadContextFromClasspath(ssl.keystorePassword, ssl.keyManagerPassword) + + def builder[F[_]: Async]: F[BlazeServerBuilder[F]] = + context.map { sslContext => + BlazeServerBuilder[F] + .bindHttp(8443) + .withSslContext(sslContext) + } + + def resource[F[_]: Async]: Resource[F, Server] = + for { + b <- Resource.eval(builder[F]) + server <- b.withHttpApp(BlazeExampleApp.httpApp).resource + } yield server +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExampleWithRedirect.scala b/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExampleWithRedirect.scala new file mode 100644 index 000000000..8a04438a6 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/BlazeSslExampleWithRedirect.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s +package blaze + +import cats.effect._ +import fs2._ +import org.http4s.blaze.server.BlazeServerBuilder + +object BlazeSslExampleWithRedirect extends IOApp { + import BlazeSslExampleWithRedirectApp._ + + override def run(args: List[String]): IO[ExitCode] = + sslStream[IO] + .mergeHaltBoth(redirectStream[IO]) + .compile + .drain + .as(ExitCode.Success) +} + +object BlazeSslExampleWithRedirectApp { + def redirectStream[F[_]: Async]: Stream[F, ExitCode] = + BlazeServerBuilder[F] + .bindHttp(8080) + .withHttpApp(ssl.redirectApp(8443)) + .serve + + def sslStream[F[_]: Async]: Stream[F, ExitCode] = + Stream.eval(BlazeSslExampleApp.builder[F]).flatMap(_.serve) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/BlazeWebSocketExample.scala b/examples/src/main/scala/com/example/http4s/blaze/BlazeWebSocketExample.scala new file mode 100644 index 000000000..7324acb78 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/BlazeWebSocketExample.scala @@ -0,0 +1,91 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze + +import cats.effect._ +import cats.effect.std.Queue +import cats.syntax.all._ +import fs2._ +import org.http4s._ +import org.http4s.blaze.server.BlazeServerBuilder +import org.http4s.dsl.Http4sDsl +import org.http4s.implicits._ +import org.http4s.server.websocket._ +import org.http4s.websocket.WebSocketFrame +import org.http4s.websocket.WebSocketFrame._ + +import scala.concurrent.duration._ + +object BlazeWebSocketExample extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + BlazeWebSocketExampleApp[IO].stream.compile.drain.as(ExitCode.Success) +} + +class BlazeWebSocketExampleApp[F[_]](implicit F: Async[F]) extends Http4sDsl[F] { + def routes(wsb: WebSocketBuilder2[F]): HttpRoutes[F] = + HttpRoutes.of[F] { + case GET -> Root / "hello" => + Ok("Hello world.") + + case GET -> Root / "ws" => + val toClient: Stream[F, WebSocketFrame] = + Stream.awakeEvery[F](1.seconds).map(d => Text(s"Ping! $d")) + val fromClient: Pipe[F, WebSocketFrame, Unit] = _.evalMap { + case Text(t, _) => F.delay(println(t)) + case f => F.delay(println(s"Unknown type: $f")) + } + wsb.build(toClient, fromClient) + + case GET -> Root / "wsecho" => + val echoReply: Pipe[F, WebSocketFrame, WebSocketFrame] = + _.collect { + case Text(msg, _) => Text("You sent the server: " + msg) + case _ => Text("Something new") + } + + /* Note that this use of a queue is not typical of http4s applications. + * This creates a single queue to connect the input and output activity + * on the WebSocket together. The queue is therefore not accessible outside + * of the scope of this single HTTP request to connect a WebSocket. + * + * While this meets the contract of the service to echo traffic back to + * its source, many applications will want to create the queue object at + * a higher level and pass it into the "routes" method or the containing + * class constructor in order to share the queue (or some other concurrency + * object) across multiple requests, or to scope it to the application itself + * instead of to a request. + */ + Queue + .unbounded[F, Option[WebSocketFrame]] + .flatMap { q => + val d: Stream[F, WebSocketFrame] = Stream.fromQueueNoneTerminated(q).through(echoReply) + val e: Pipe[F, WebSocketFrame, Unit] = _.enqueueNoneTerminated(q) + wsb.build(d, e) + } + } + + def stream: Stream[F, ExitCode] = + BlazeServerBuilder[F] + .bindHttp(8080) + .withHttpWebSocketApp(routes(_).orNotFound) + .serve +} + +object BlazeWebSocketExampleApp { + def apply[F[_]: Async]: BlazeWebSocketExampleApp[F] = + new BlazeWebSocketExampleApp[F] +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/ClientExample.scala b/examples/src/main/scala/com/example/http4s/blaze/ClientExample.scala new file mode 100644 index 000000000..85a7d9961 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/ClientExample.scala @@ -0,0 +1,69 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze + +import cats.effect._ +import io.circe.generic.auto._ +import org.http4s.Status.NotFound +import org.http4s.Status.Successful +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.circe._ +import org.http4s.client.Client +import org.http4s.syntax.all._ + +object ClientExample extends IOApp { + def printGooglePage(client: Client[IO]): IO[Unit] = { + val page: IO[String] = client.expect[String](uri"https://www.google.com/") + IO.parSequenceN(2)((1 to 2).toList.map { _ => + for { + // each execution of the effect will refetch the page! + pageContent <- page + firstBytes = pageContent.take(72) + _ <- IO.println(firstBytes) + } yield () + }).as(()) + } + + def matchOnResponseCode(client: Client[IO]): IO[Unit] = { + final case class Foo(bar: String) + + for { + // Match on response code! + page <- client.get(uri"http://http4s.org/resources/foo.json") { + case Successful(resp) => + // decodeJson is defined for Circe, just need the right decoder! + resp.decodeJson[Foo].map("Received response: " + _) + case NotFound(_) => IO.pure("Not Found!!!") + case resp => IO.pure("Failed: " + resp.status) + } + _ <- IO.println(page) + } yield () + } + + def getSite(client: Client[IO]): IO[Unit] = + for { + _ <- printGooglePage(client) + // We can do much more: how about decoding some JSON to a scala object + // after matching based on the response status code? + _ <- matchOnResponseCode(client) + } yield () + + def run(args: List[String]): IO[ExitCode] = + BlazeClientBuilder[IO].resource + .use(getSite) + .as(ExitCode.Success) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/ClientMultipartPostExample.scala b/examples/src/main/scala/com/example/http4s/blaze/ClientMultipartPostExample.scala new file mode 100644 index 000000000..48d79c4b1 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/ClientMultipartPostExample.scala @@ -0,0 +1,63 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze + +import cats.effect.ExitCode +import cats.effect.IO +import cats.effect.IOApp +import org.http4s.Uri._ +import org.http4s._ +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.client.Client +import org.http4s.client.dsl.Http4sClientDsl +import org.http4s.headers._ +import org.http4s.multipart._ +import org.http4s.syntax.literals._ + +import java.net.URL + +object ClientMultipartPostExample extends IOApp with Http4sClientDsl[IO] { + private val bottle: URL = getClass.getResource("/beerbottle.png") + + def go(client: Client[IO], multiparts: Multiparts[IO]): IO[String] = { + val url = Uri( + scheme = Some(Scheme.http), + authority = Some(Authority(host = RegName("httpbin.org"))), + path = path"/post" + ) + + multiparts + .multipart( + Vector( + Part.formData("text", "This is text."), + Part.fileData("BALL", bottle, `Content-Type`(MediaType.image.png)) + ) + ) + .flatMap { multipart => + val request: Request[IO] = Method.POST(multipart, url).withHeaders(multipart.headers) + client.expect[String](request) + } + } + + def run(args: List[String]): IO[ExitCode] = + Multiparts.forSync[IO].flatMap { multiparts => + BlazeClientBuilder[IO].resource + .use(go(_, multiparts)) + .flatMap(s => IO.println(s)) + .as(ExitCode.Success) + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/ClientPostExample.scala b/examples/src/main/scala/com/example/http4s/blaze/ClientPostExample.scala new file mode 100644 index 000000000..c31684885 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/ClientPostExample.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze + +import cats.effect._ +import org.http4s._ +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.client.dsl.Http4sClientDsl +import org.http4s.dsl.io._ +import org.http4s.syntax.all._ + +object ClientPostExample extends IOApp with Http4sClientDsl[IO] { + def run(args: List[String]): IO[ExitCode] = { + val req = POST(UrlForm("q" -> "http4s"), uri"https://duckduckgo.com/") + val responseBody = BlazeClientBuilder[IO].resource.use(_.expect[String](req)) + responseBody.flatMap(resp => IO.println(resp)).as(ExitCode.Success) + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/StreamUtils.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/StreamUtils.scala new file mode 100644 index 000000000..2665eeea5 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/StreamUtils.scala @@ -0,0 +1,33 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo + +import cats.effect.Sync +import fs2.Stream + +trait StreamUtils[F[_]] { + def evalF[A](thunk: => A)(implicit F: Sync[F]): Stream[F, A] = Stream.eval(F.delay(thunk)) + def putStrLn(value: String)(implicit F: Sync[F]): Stream[F, Unit] = evalF(println(value)) + def putStr(value: String)(implicit F: Sync[F]): Stream[F, Unit] = evalF(print(value)) + def env(name: String)(implicit F: Sync[F]): Stream[F, Option[String]] = evalF(sys.env.get(name)) + def error(msg: String)(implicit F: Sync[F]): Stream[F, String] = + Stream.raiseError(new Exception(msg)) +} + +object StreamUtils { + implicit def syncInstance[F[_]]: StreamUtils[F] = new StreamUtils[F] {} +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/client/MultipartClient.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/client/MultipartClient.scala new file mode 100644 index 000000000..1d52f0474 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/client/MultipartClient.scala @@ -0,0 +1,69 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.client + +import cats.effect.ExitCode +import cats.effect.IO +import cats.effect.IOApp +import cats.effect.Resource +import com.example.http4s.blaze.demo.StreamUtils +import fs2.Stream +import org.http4s.Method._ +import org.http4s._ +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.client.Client +import org.http4s.client.dsl.Http4sClientDsl +import org.http4s.headers.`Content-Type` +import org.http4s.implicits._ +import org.http4s.multipart.Multiparts +import org.http4s.multipart.Part + +import java.net.URL + +object MultipartClient extends MultipartHttpClient + +class MultipartHttpClient(implicit S: StreamUtils[IO]) extends IOApp with Http4sClientDsl[IO] { + private val image: IO[URL] = IO.blocking(getClass.getResource("/beerbottle.png")) + + private def request(multiparts: Multiparts[IO]) = + for { + url <- image + body <- multiparts.multipart( + Vector( + Part.formData("name", "gvolpe"), + Part.fileData("rick", url, `Content-Type`(MediaType.image.png)) + ) + ) + } yield POST(body, uri"http://localhost:8080/v1/multipart").withHeaders(body.headers) + + private val resources: Resource[IO, (Client[IO], Multiparts[IO])] = + for { + client <- BlazeClientBuilder[IO].resource + multiparts <- Resource.eval(Multiparts.forSync[IO]) + } yield (client, multiparts) + + private val example = + for { + (client, multiparts) <- Stream.resource(resources) + req <- Stream.eval(request(multiparts)) + value <- Stream.eval(client.expect[String](req)) + _ <- S.putStrLn(value) + } yield () + + override def run(args: List[String]): IO[ExitCode] = + example.compile.drain.as(ExitCode.Success) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/client/StreamClient.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/client/StreamClient.scala new file mode 100644 index 000000000..a8541029c --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/client/StreamClient.scala @@ -0,0 +1,51 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.client + +import cats.effect.Async +import cats.effect.ExitCode +import cats.effect.IO +import cats.effect.IOApp +import com.example.http4s.blaze.demo.StreamUtils +import io.circe.Json +import org.http4s.Request +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.syntax.literals._ +import org.typelevel.jawn.Facade + +object StreamClient extends IOApp { + def run(args: List[String]): IO[ExitCode] = + new HttpClient[IO].run.as(ExitCode.Success) +} + +class HttpClient[F[_]](implicit F: Async[F], S: StreamUtils[F]) { + implicit val jsonFacade: Facade[Json] = + new io.circe.jawn.CirceSupportParser(None, false).facade + + def run: F[Unit] = + BlazeClientBuilder[F].stream + .flatMap { client => + val request = + Request[F](uri = uri"http://localhost:8080/v1/dirs?depth=3") + for { + response <- client.stream(request).flatMap(_.body.chunks.through(fs2.text.utf8.decodeC)) + _ <- S.putStr(response) + } yield () + } + .compile + .drain +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/Module.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/Module.scala new file mode 100644 index 000000000..6fce19b91 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/Module.scala @@ -0,0 +1,77 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server + +import cats.data.OptionT +import cats.effect._ +import cats.syntax.semigroupk._ +import com.example.http4s.blaze.demo.server.endpoints._ +import com.example.http4s.blaze.demo.server.endpoints.auth.BasicAuthHttpEndpoint +import com.example.http4s.blaze.demo.server.endpoints.auth.GitHubHttpEndpoint +import com.example.http4s.blaze.demo.server.service.FileService +import com.example.http4s.blaze.demo.server.service.GitHubService +import org.http4s.HttpRoutes +import org.http4s.client.Client +import org.http4s.server.HttpMiddleware +import org.http4s.server.middleware.AutoSlash +import org.http4s.server.middleware.ChunkAggregator +import org.http4s.server.middleware.GZip +import org.http4s.server.middleware.Timeout + +import scala.concurrent.duration._ + +class Module[F[_]: Async](client: Client[F]) { + private val fileService = new FileService[F] + + private val gitHubService = new GitHubService[F](client) + + def middleware: HttpMiddleware[F] = { (routes: HttpRoutes[F]) => + GZip(routes) + }.compose(routes => AutoSlash(routes)) + + val fileHttpEndpoint: HttpRoutes[F] = + new FileHttpEndpoint[F](fileService).service + + val nonStreamFileHttpEndpoint: HttpRoutes[F] = + ChunkAggregator(OptionT.liftK[F])(fileHttpEndpoint) + + private val hexNameHttpEndpoint: HttpRoutes[F] = + new HexNameHttpEndpoint[F].service + + private val compressedEndpoints: HttpRoutes[F] = + middleware(hexNameHttpEndpoint) + + private val timeoutHttpEndpoint: HttpRoutes[F] = + new TimeoutHttpEndpoint[F].service + + private val timeoutEndpoints: HttpRoutes[F] = + Timeout(1.second)(timeoutHttpEndpoint) + + private val multipartHttpEndpoint: HttpRoutes[F] = + new MultipartHttpEndpoint[F](fileService).service + + private val gitHubHttpEndpoint: HttpRoutes[F] = + new GitHubHttpEndpoint[F](gitHubService).service + + val basicAuthHttpEndpoint: HttpRoutes[F] = + new BasicAuthHttpEndpoint[F].service + + val httpServices: HttpRoutes[F] = ( + compressedEndpoints <+> timeoutEndpoints <+> multipartHttpEndpoint + <+> gitHubHttpEndpoint + ) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/Server.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/Server.scala new file mode 100644 index 000000000..5790991c4 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/Server.scala @@ -0,0 +1,49 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server + +import cats.effect._ +import fs2.Stream +import org.http4s.HttpApp +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.blaze.server.BlazeServerBuilder +import org.http4s.server.Router + +object Server extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + HttpServer.stream[IO].compile.drain.as(ExitCode.Success) +} + +object HttpServer { + def httpApp[F[_]: Sync](ctx: Module[F]): HttpApp[F] = + Router( + s"/${endpoints.ApiVersion}/protected" -> ctx.basicAuthHttpEndpoint, + s"/${endpoints.ApiVersion}" -> ctx.fileHttpEndpoint, + s"/${endpoints.ApiVersion}/nonstream" -> ctx.nonStreamFileHttpEndpoint, + "/" -> ctx.httpServices + ).orNotFound + + def stream[F[_]: Async]: Stream[F, ExitCode] = + for { + client <- BlazeClientBuilder[F].stream + ctx <- Stream(new Module[F](client)) + exitCode <- BlazeServerBuilder[F] + .bindHttp(8080) + .withHttpApp(httpApp(ctx)) + .serve + } yield exitCode +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/FileHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/FileHttpEndpoint.scala new file mode 100644 index 000000000..43b9f6aac --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/FileHttpEndpoint.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints + +import cats.effect.Sync +import com.example.http4s.blaze.demo.server.service.FileService +import org.http4s._ +import org.http4s.dsl.Http4sDsl + +class FileHttpEndpoint[F[_]: Sync](fileService: FileService[F]) extends Http4sDsl[F] { + object DepthQueryParamMatcher extends OptionalQueryParamDecoderMatcher[Int]("depth") + + val service: HttpRoutes[F] = HttpRoutes.of[F] { + case GET -> Root / "dirs" :? DepthQueryParamMatcher(depth) => + Ok(fileService.homeDirectories(depth)) + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/HexNameHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/HexNameHttpEndpoint.scala new file mode 100644 index 000000000..c0cca0b42 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/HexNameHttpEndpoint.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints + +import cats.effect.Sync +import org.http4s.dsl.Http4sDsl +import org.http4s.{ApiVersion => _, _} + +class HexNameHttpEndpoint[F[_]: Sync] extends Http4sDsl[F] { + object NameQueryParamMatcher extends QueryParamDecoderMatcher[String]("name") + + val service: HttpRoutes[F] = HttpRoutes.of { + case GET -> Root / ApiVersion / "hex" :? NameQueryParamMatcher(name) => + Ok(name.getBytes("UTF-8").map("%02x".format(_)).mkString) + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/MultipartHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/MultipartHttpEndpoint.scala new file mode 100644 index 000000000..f14ce4002 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/MultipartHttpEndpoint.scala @@ -0,0 +1,42 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints + +import cats.effect.Concurrent +import cats.syntax.all._ +import com.example.http4s.blaze.demo.server.service.FileService +import org.http4s.EntityDecoder.multipart +import org.http4s.dsl.Http4sDsl +import org.http4s.multipart.Part +import org.http4s.{ApiVersion => _, _} + +class MultipartHttpEndpoint[F[_]: Concurrent](fileService: FileService[F]) extends Http4sDsl[F] { + val service: HttpRoutes[F] = HttpRoutes.of { + case GET -> Root / ApiVersion / "multipart" => + Ok("Send a file (image, sound, etc) via POST Method") + + case req @ POST -> Root / ApiVersion / "multipart" => + req.decodeWith(multipart[F], strict = true) { response => + def filterFileTypes(part: Part[F]): Boolean = + part.headers.headers.exists(_.value.contains("filename")) + + val stream = response.parts.filter(filterFileTypes).traverse(fileService.store) + + Ok(stream.map(_ => s"Multipart file parsed successfully > ${response.parts}")) + } + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/TimeoutHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/TimeoutHttpEndpoint.scala new file mode 100644 index 000000000..cdca72a7b --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/TimeoutHttpEndpoint.scala @@ -0,0 +1,33 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints + +import cats.effect.Async +import cats.syntax.all._ +import org.http4s.dsl.Http4sDsl +import org.http4s.{ApiVersion => _, _} + +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration +import scala.util.Random + +class TimeoutHttpEndpoint[F[_]](implicit F: Async[F]) extends Http4sDsl[F] { + val service: HttpRoutes[F] = HttpRoutes.of { case GET -> Root / ApiVersion / "timeout" => + val randomDuration = FiniteDuration(Random.nextInt(3) * 1000L, TimeUnit.MILLISECONDS) + F.sleep(randomDuration) *> Ok("delayed response") + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/AuthRepository.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/AuthRepository.scala new file mode 100644 index 000000000..166c7764d --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/AuthRepository.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints.auth + +import cats.effect.Sync +import cats.syntax.apply._ +import org.http4s.BasicCredentials + +trait AuthRepository[F[_], A] { + def persist(entity: A): F[Unit] + def find(entity: A): F[Option[A]] +} + +object AuthRepository { + implicit def authUserRepo[F[_]](implicit F: Sync[F]): AuthRepository[F, BasicCredentials] = + new AuthRepository[F, BasicCredentials] { + private val storage = scala.collection.mutable.Set[BasicCredentials]( + BasicCredentials("gvolpe", "123456") + ) + override def persist(entity: BasicCredentials): F[Unit] = + F.delay(storage.add(entity)) *> F.unit + override def find(entity: BasicCredentials): F[Option[BasicCredentials]] = + F.delay(storage.find(_ == entity)) + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/BasicAuthHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/BasicAuthHttpEndpoint.scala new file mode 100644 index 000000000..b0ff9251b --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/BasicAuthHttpEndpoint.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints.auth + +import cats.effect.Sync +import org.http4s._ +import org.http4s.dsl.Http4sDsl +import org.http4s.server.AuthMiddleware +import org.http4s.server.middleware.authentication.BasicAuth + +// Use this header --> Authorization: Basic Z3ZvbHBlOjEyMzQ1Ng== +class BasicAuthHttpEndpoint[F[_]](implicit F: Sync[F], R: AuthRepository[F, BasicCredentials]) + extends Http4sDsl[F] { + private val authedRoutes: AuthedRoutes[BasicCredentials, F] = AuthedRoutes.of { + case GET -> Root as user => + Ok(s"Access Granted: ${user.username}") + } + + private val authMiddleware: AuthMiddleware[F, BasicCredentials] = + BasicAuth[F, BasicCredentials]("Protected Realm", R.find) + + val service: HttpRoutes[F] = authMiddleware(authedRoutes) +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/GitHubHttpEndpoint.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/GitHubHttpEndpoint.scala new file mode 100644 index 000000000..d6af5c0e0 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/auth/GitHubHttpEndpoint.scala @@ -0,0 +1,43 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.endpoints.auth + +import cats.effect.Sync +import cats.syntax.flatMap._ +import cats.syntax.functor._ +import com.example.http4s.blaze.demo.server.endpoints.ApiVersion +import com.example.http4s.blaze.demo.server.service.GitHubService +import org.http4s._ +import org.http4s.dsl.Http4sDsl + +class GitHubHttpEndpoint[F[_]](gitHubService: GitHubService[F])(implicit F: Sync[F]) + extends Http4sDsl[F] { + object CodeQuery extends QueryParamDecoderMatcher[String]("code") + object StateQuery extends QueryParamDecoderMatcher[String]("state") + + val service: HttpRoutes[F] = HttpRoutes.of { + case GET -> Root / ApiVersion / "github" => + Ok(gitHubService.authorize) + + // OAuth2 Callback URI + case GET -> Root / ApiVersion / "login" / "github" :? CodeQuery(code) :? StateQuery(state) => + for { + o <- Ok() + code <- gitHubService.accessToken(code, state).flatMap(gitHubService.userData) + } yield o.withEntity(code).putHeaders("Content-Type" -> "application/json") + } +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/package.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/package.scala new file mode 100644 index 000000000..4ae78d959 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/endpoints/package.scala @@ -0,0 +1,21 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server + +package object endpoints { + val ApiVersion = "v1" +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/FileService.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/FileService.scala new file mode 100644 index 000000000..4d0035101 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/FileService.scala @@ -0,0 +1,60 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.service + +import cats.effect.Async +import com.example.http4s.blaze.demo.StreamUtils +import fs2.Stream +import fs2.io.file.Files +import fs2.io.file.Path +import org.http4s.multipart.Part + +import java.io.File +import java.nio.file.Paths + +class FileService[F[_]](implicit F: Async[F], S: StreamUtils[F]) { + def homeDirectories(depth: Option[Int]): Stream[F, String] = + S.env("HOME").flatMap { maybePath => + val ifEmpty = S.error("HOME environment variable not found!") + maybePath.fold(ifEmpty)(directories(_, depth.getOrElse(1))) + } + + def directories(path: String, depth: Int): Stream[F, String] = { + def dir(f: File, d: Int): Stream[F, File] = { + val dirs = Stream.emits(f.listFiles().toSeq).filter(_.isDirectory) + + if (d <= 0) Stream.empty + else if (d == 1) dirs + else dirs ++ dirs.flatMap(x => dir(x, d - 1)) + } + + S.evalF(new File(path)).flatMap { file => + dir(file, depth) + .map(_.getName) + .filter(!_.startsWith(".")) + .intersperse("\n") + } + } + + def store(part: Part[F]): Stream[F, Unit] = + for { + home <- S.evalF(sys.env.getOrElse("HOME", "/tmp")) + filename <- S.evalF(part.filename.getOrElse("sample")) + path <- S.evalF(Paths.get(s"$home/$filename")) + result <- part.body.through(Files[F].writeAll(Path.fromNioPath(path))) + } yield result +} diff --git a/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/GitHubService.scala b/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/GitHubService.scala new file mode 100644 index 000000000..3a84f569d --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/blaze/demo/server/service/GitHubService.scala @@ -0,0 +1,71 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s.blaze.demo.server.service + +import cats.effect.Concurrent +import cats.syntax.functor._ +import com.example.http4s.blaze.demo.server.endpoints.ApiVersion +import fs2.Stream +import io.circe.generic.auto._ +import org.http4s.Request +import org.http4s.circe._ +import org.http4s.client.Client +import org.http4s.client.dsl.Http4sClientDsl +import org.http4s.syntax.literals._ + +// See: https://developer.github.com/apps/building-oauth-apps/authorization-options-for-oauth-apps/#web-application-flow +class GitHubService[F[_]: Concurrent](client: Client[F]) extends Http4sClientDsl[F] { + // NEVER make this data public! This is just a demo! + private val ClientId = "959ea01cd3065cad274a" + private val ClientSecret = "53901db46451977e6331432faa2616ba24bc2550" + + private val RedirectUri = s"http://localhost:8080/$ApiVersion/login/github" + + private case class AccessTokenResponse(access_token: String) + + val authorize: Stream[F, Byte] = { + val uri = uri"https://github.com" + .withPath(path"/login/oauth/authorize") + .withQueryParam("client_id", ClientId) + .withQueryParam("redirect_uri", RedirectUri) + .withQueryParam("scopes", "public_repo") + .withQueryParam("state", "test_api") + + client.stream(Request[F](uri = uri)).flatMap(_.body) + } + + def accessToken(code: String, state: String): F[String] = { + val uri = uri"https://github.com" + .withPath(path"/login/oauth/access_token") + .withQueryParam("client_id", ClientId) + .withQueryParam("client_secret", ClientSecret) + .withQueryParam("code", code) + .withQueryParam("redirect_uri", RedirectUri) + .withQueryParam("state", state) + + client + .expect[AccessTokenResponse](Request[F](uri = uri))(jsonOf[F, AccessTokenResponse]) + .map(_.access_token) + } + + def userData(accessToken: String): F[String] = { + val request = Request[F](uri = uri"https://api.github.com/user") + .putHeaders("Authorization" -> s"token $accessToken") + + client.expect[String](request) + } +} diff --git a/examples/src/main/scala/com/example/http4s/ssl.scala b/examples/src/main/scala/com/example/http4s/ssl.scala new file mode 100644 index 000000000..e195e2993 --- /dev/null +++ b/examples/src/main/scala/com/example/http4s/ssl.scala @@ -0,0 +1,89 @@ +/* + * Copyright 2014 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.http4s + +import cats.effect.Sync +import cats.syntax.all._ +import org.http4s.HttpApp +import org.http4s.Uri.Authority +import org.http4s.Uri.RegName +import org.http4s.Uri.Scheme +import org.http4s.dsl.Http4sDsl +import org.http4s.headers.Host +import org.http4s.headers.Location +import org.http4s.server.SSLKeyStoreSupport.StoreInfo + +import java.nio.file.Paths +import java.security.KeyStore +import java.security.Security +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext + +object ssl { + val keystorePassword: String = "password" + val keyManagerPassword: String = "secure" + + val keystorePath: String = Paths.get("../server.jks").toAbsolutePath.toString + + val storeInfo: StoreInfo = StoreInfo(keystorePath, keystorePassword) + + def loadContextFromClasspath[F[_]](keystorePassword: String, keyManagerPass: String)(implicit + F: Sync[F] + ): F[SSLContext] = + F.delay { + val ksStream = this.getClass.getResourceAsStream("/server.jks") + val ks = KeyStore.getInstance("JKS") + ks.load(ksStream, keystorePassword.toCharArray) + ksStream.close() + + val kmf = KeyManagerFactory.getInstance( + Option(Security.getProperty("ssl.KeyManagerFactory.algorithm")) + .getOrElse(KeyManagerFactory.getDefaultAlgorithm) + ) + + kmf.init(ks, keyManagerPass.toCharArray) + + val context = SSLContext.getInstance("TLS") + context.init(kmf.getKeyManagers, null, null) + + context + } + + def redirectApp[F[_]: Sync](securePort: Int): HttpApp[F] = { + val dsl = new Http4sDsl[F] {} + import dsl._ + + HttpApp[F] { request => + request.headers.get[Host] match { + case Some(Host(host @ _, _)) => + val baseUri = request.uri.copy( + scheme = Scheme.https.some, + authority = Some( + Authority( + userInfo = request.uri.authority.flatMap(_.userInfo), + host = RegName(host), + port = securePort.some + ) + ) + ) + MovedPermanently(Location(baseUri.withPath(request.uri.path))) + case _ => + BadRequest() + } + } + } +}