diff --git a/io/js/src/main/scala/fs2/io/internal/facade/events.scala b/io/js/src/main/scala/fs2/io/internal/facade/events.scala index 8cdb1041db..4e9c0fa0aa 100644 --- a/io/js/src/main/scala/fs2/io/internal/facade/events.scala +++ b/io/js/src/main/scala/fs2/io/internal/facade/events.scala @@ -58,6 +58,21 @@ private[io] trait EventEmitter extends js.Object { } private[io] object EventEmitter { + final class Scope private[EventEmitter] { + private[EventEmitter] val cleanup = new js.Array[js.Function0[Unit]] + } + + def openScope[F[_]](implicit F: Sync[F]): Resource[F, Scope] = + Resource.make(F.delay(new Scope)) { scope => + F.delay { + scope.cleanup.foreach { task => + try + task() + catch { case _: Throwable => () } + } + } + } + implicit class ops(val eventTarget: EventEmitter) extends AnyVal { def registerListener[F[_], E](eventName: String, dispatcher: Dispatcher[F])( @@ -87,5 +102,40 @@ private[io] object EventEmitter { eventTarget.once(eventName, fn) Some(F.delay(eventTarget.removeListener(eventName, fn))) } + + def unsafeRegisterListener[F[_], E](eventName: String, dispatcher: Dispatcher[F], scope: Scope)( + listener: E => F[Unit] + ): Unit = { + val fn: js.Function1[E, Unit] = e => dispatcher.unsafeRunAndForget(listener(e)) + eventTarget.on(eventName, fn) + scope.cleanup.push(() => eventTarget.removeListener(eventName, fn)) + () + } + + def unsafeRegisterOneTimeListener0[F[_], E]( + eventName: String, + dispatcher: Dispatcher[F], + scope: Scope + )( + listener: () => F[Unit] + ): Unit = { + val fn: js.Function0[Unit] = () => dispatcher.unsafeRunAndForget(listener()) + eventTarget.once(eventName, fn) + scope.cleanup.push(() => eventTarget.removeListener(eventName, fn)) + () + } + + def unsafeRegisterOneTimeListener[F[_], E]( + eventName: String, + dispatcher: Dispatcher[F], + scope: Scope + )( + listener: E => F[Unit] + ): Unit = { + val fn: js.Function1[E, Unit] = e => dispatcher.unsafeRunAndForget(listener(e)) + eventTarget.once(eventName, fn) + scope.cleanup.push(() => eventTarget.removeListener(eventName, fn)) + () + } } } diff --git a/io/js/src/main/scala/fs2/io/ioplatform.scala b/io/js/src/main/scala/fs2/io/ioplatform.scala index 011f770764..1e3859c9f5 100644 --- a/io/js/src/main/scala/fs2/io/ioplatform.scala +++ b/io/js/src/main/scala/fs2/io/ioplatform.scala @@ -31,7 +31,6 @@ import cats.effect.std.Queue import cats.effect.syntax.all._ import cats.syntax.all._ import fs2.concurrent.Channel -import fs2.io.internal.MicrotaskExecutor import fs2.io.internal.facade import java.nio.charset.Charset @@ -63,37 +62,42 @@ private[fs2] trait ioplatform { dispatcher <- Dispatcher.sequential[F] channel <- Channel.unbounded[F, Unit].toResource error <- F.deferred[Throwable].toResource - readableResource = for { - readable <- Resource.makeCase(F.delay(thunk)) { - case (readable, Resource.ExitCase.Succeeded) => - F.delay { - if (!readable.readableEnded & destroyIfNotEnded) - readable.destroy() - } - case (readable, Resource.ExitCase.Errored(_)) => - // tempting, but don't propagate the error! - // that would trigger a unhandled Node.js error that circumvents FS2/CE error channels - F.delay(readable.destroy()) - case (readable, Resource.ExitCase.Canceled) => - if (destroyIfCanceled) - F.delay(readable.destroy()) - else - F.unit - } - _ <- readable.registerListener[F, Any]("readable", dispatcher)(_ => channel.send(()).void) - _ <- readable.registerListener[F, Any]("end", dispatcher)(_ => channel.close.void) - _ <- readable.registerListener[F, Any]("close", dispatcher)(_ => channel.close.void) - _ <- readable.registerListener[F, js.Error]("error", dispatcher) { e => - error.complete(js.JavaScriptException(e)).void + scope <- facade.events.EventEmitter.openScope + readable <- Resource.makeCase { + F.delay { + val readable = thunk + + readable.unsafeRegisterListener[F, Any]("readable", dispatcher, scope) { _ => + channel.send(()).void + } + readable.unsafeRegisterListener[F, Any]("end", dispatcher, scope) { _ => + channel.close.void + } + readable.unsafeRegisterListener[F, Any]("close", dispatcher, scope) { _ => + channel.close.void + } + readable.unsafeRegisterListener[F, js.Error]("error", dispatcher, scope) { e => + error.complete(js.JavaScriptException(e)).void + } + + readable } - } yield readable - // Implementation note: why run on the MicrotaskExecutor? - // In many cases creating a `Readable` starts async side-effects (e.g. negotiating TLS handshake or opening a file handle). - // Furthermore, these side-effects will invoke the listeners we register to the `Readable`. - // Therefore, it is critical that the listeners are registered to the `Readable` _before_ these async side-effects occur: - // in other words, before we next yield (cede) to the event loop. Because an arbitrary effect `F` (particularly `IO`) may cede at any time, - // our only recourse is to run the entire creation/listener registration process on the microtask executor. - readable <- readableResource.evalOn(MicrotaskExecutor) + } { + case (readable, Resource.ExitCase.Succeeded) => + F.delay { + if (!readable.readableEnded & destroyIfNotEnded) + readable.destroy() + } + case (readable, Resource.ExitCase.Errored(_)) => + // tempting, but don't propagate the error! + // that would trigger a unhandled Node.js error that circumvents FS2/CE error channels + F.delay(readable.destroy()) + case (readable, Resource.ExitCase.Canceled) => + if (destroyIfCanceled) + F.delay(readable.destroy()) + else + F.unit + } stream = (channel.stream .concurrently(Stream.eval(error.get.flatMap(F.raiseError[Unit]))) >> diff --git a/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala b/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala index a506e02f58..3868eb2ff8 100644 --- a/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala +++ b/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala @@ -62,79 +62,75 @@ private[tls] trait TLSContextCompanionPlatform { self: TLSContext.type => clientMode: Boolean, params: TLSParameters, logger: TLSLogger[F] - ): Resource[F, TLSSocket[F]] = (Dispatcher.sequential[F], Dispatcher.parallel[F]) - .flatMapN { (seqDispatcher, parDispatcher) => - if (clientMode) { - Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { handshake => - TLSSocket - .forAsync( - socket, - sock => { - val options = params.toTLSConnectOptions(parDispatcher) - options.secureContext = context - if (insecure) - options.rejectUnauthorized = false - options.enableTrace = logger != TLSLogger.Disabled - options.socket = sock - val tlsSock = facade.tls.connect(options) - tlsSock.once( - "secureConnect", - () => seqDispatcher.unsafeRunAndForget(handshake.complete(Either.unit)) - ) - tlsSock.once[js.Error]( - "error", - e => - seqDispatcher.unsafeRunAndForget( - handshake.complete(Left(new js.JavaScriptException(e))) + ): Resource[F, TLSSocket[F]] = + (Dispatcher.sequential[F], Dispatcher.parallel[F], facade.events.EventEmitter.openScope) + .flatMapN { (seqDispatcher, parDispatcher, scope) => + if (clientMode) { + Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { handshake => + TLSSocket + .forAsync( + socket, + sock => { + val options = params.toTLSConnectOptions(parDispatcher) + options.secureContext = context + if (insecure) + options.rejectUnauthorized = false + options.enableTrace = logger != TLSLogger.Disabled + options.socket = sock + val tlsSock = facade.tls.connect(options) + tlsSock + .unsafeRegisterOneTimeListener0("secureConnect", seqDispatcher, scope)( + () => handshake.complete(Either.unit).void ) - ) - tlsSock - } - ) - .evalTap(_ => handshake.get.rethrow) - } - } else { - Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { verifyError => - TLSSocket - .forAsync( - socket, - sock => { - val options = params.toTLSSocketOptions(parDispatcher) - options.secureContext = context - if (insecure) - options.rejectUnauthorized = false - options.enableTrace = logger != TLSLogger.Disabled - options.isServer = true - val tlsSock = new facade.tls.TLSSocket(sock, options) - tlsSock.once( - "secure", - { () => - val requestCert = options.requestCert.getOrElse(false) - val rejectUnauthorized = options.rejectUnauthorized.getOrElse(true) - val result = - if (requestCert && rejectUnauthorized) - Option(tlsSock.ssl.verifyError()) - .map(e => new JavaScriptSSLException(js.JavaScriptException(e))) - .toLeft(()) - else Either.unit - seqDispatcher.unsafeRunAndForget(verifyError.complete(result)) + tlsSock.unsafeRegisterOneTimeListener[F, js.Error]( + "error", + seqDispatcher, + scope + )(e => handshake.complete(Left(new js.JavaScriptException(e))).void) + tlsSock + } + ) + .evalTap(_ => handshake.get.rethrow) + } + } else { + Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { verifyError => + TLSSocket + .forAsync( + socket, + sock => { + val options = params.toTLSSocketOptions(parDispatcher) + options.secureContext = context + if (insecure) + options.rejectUnauthorized = false + options.enableTrace = logger != TLSLogger.Disabled + options.isServer = true + val tlsSock = new facade.tls.TLSSocket(sock, options) + tlsSock.unsafeRegisterOneTimeListener0("secure", seqDispatcher, scope) { + () => + val requestCert = options.requestCert.getOrElse(false) + val rejectUnauthorized = options.rejectUnauthorized.getOrElse(true) + val result = + if (requestCert && rejectUnauthorized) + // side-effect must run in the callback + Option(tlsSock.ssl.verifyError()) + .map(e => new JavaScriptSSLException(js.JavaScriptException(e))) + .toLeft(()) + else Either.unit + verifyError.complete(result).void } - ) - tlsSock.once[js.Error]( - "error", - e => - seqDispatcher.unsafeRunAndForget( - verifyError.complete(Left(new js.JavaScriptException(e))) - ) - ) - tlsSock - } - ) - .evalTap(_ => verifyError.get.rethrow) + tlsSock.unsafeRegisterOneTimeListener[F, js.Error]( + "error", + seqDispatcher, + scope + )(e => verifyError.complete(Left(new js.JavaScriptException(e))).void) + tlsSock + } + ) + .evalTap(_ => verifyError.get.rethrow) + } } } - } - .adaptError { case IOException(ex) => ex } + .adaptError { case IOException(ex) => ex } } def fromSecureContext(context: SecureContext): TLSContext[F] =