Skip to content
196 changes: 113 additions & 83 deletions akka-actor/src/main/scala/akka/io/dns/internal/AsyncDnsResolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ package akka.io.dns.internal
import java.net.{ Inet4Address, Inet6Address, InetAddress, InetSocketAddress }

import scala.collection.immutable
import scala.concurrent.ExecutionContextExecutor
import scala.concurrent.Future
import scala.util.Try
import scala.concurrent.Promise
import scala.util.{ Failure, Success, Try }
import scala.util.control.NonFatal

import akka.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory }
import akka.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory, Status }
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.io.SimpleDnsCache
import akka.io.dns._
import akka.io.dns.CachePolicy.{ Never, Ttl }
import akka.io.dns.DnsProtocol.{ Ip, RequestType, Srv }
import akka.io.dns.internal.DnsClient._
import akka.pattern.{ ask, pipe }
import akka.pattern.ask
import akka.pattern.AskTimeoutException
import akka.util.{ Helpers, Timeout }
import akka.util.PrettyDuration._
Expand All @@ -37,8 +38,6 @@ private[io] final class AsyncDnsResolver(

import AsyncDnsResolver._

implicit val ec: ExecutionContextExecutor = context.dispatcher
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do kind of feel that having an implicit execution context in an actor is a bad sign, by decreasing the cognitive cost of adding a future callback within the actor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 agree


// avoid ever looking up localhost by pre-populating cache
{
val loopback = InetAddress.getLoopbackAddress
Expand Down Expand Up @@ -93,17 +92,28 @@ private[io] final class AsyncDnsResolver(
log.debug("{} cached {}", mode, resolved)
sender() ! resolved
case None =>
resolveWithResolvers(name, mode, resolvers)
.map { resolved =>
if (resolved.records.nonEmpty) {
val minTtl = (positiveCachePolicy +: resolved.records.map(_.ttl)).min
cache.put((name, mode), resolved, minTtl)
} else if (negativeCachePolicy != Never) cache.put((name, mode), resolved, negativeCachePolicy)
log.debug(s"{} resolved {}", mode, resolved)
resolved
val replyTo = sender()
resolveWithResolvers(name, mode, resolvers).onComplete {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broad idea is to limit the future callbacks inside the actor. Combining the cache updates with piping preserves the sequencing in .map { x => sideEffect(); x }.pipeTo(sender()) and allows use of parasitic for performance without making that the default context.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe pipeTo is better .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, +100 for parasitic

_ match {
case Success(resolved) =>
if (resolved.records.nonEmpty) {
val minTtl = (positiveCachePolicy +: resolved.records.map(_.ttl)).min
cache.put((name, mode), resolved, minTtl) // thread-safe structure, safe in callback
} else if (negativeCachePolicy != Never) {
cache.put((name, mode), resolved, negativeCachePolicy) // thread-safe structure, safe in callback
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it was going to just be a Map, definitely. In this case, since we're going through the expense of using a thread-safe structure, may as well exercise that :)

This does have the benefit of quickly allowing a successful resolution to immediately update the cache for requests which might be in the mailbox now (without having to do a custom mailbox). Since the only cache.get is right when we receive a DnsProtocol.Resolve, it's not much of a win (it might be a slight win under heavy load to check the cache when making follow-up queries).

}

log.debug("{} resolved {}", mode, resolved)
replyTo ! resolved

case Failure(f) =>
replyTo ! Status.Failure(f)
}
.pipeTo(sender())
}(ExecutionContexts.parasitic)
}

case Internal.Resolve(name, requestType, resolver, promise) =>
resolvePromise(name, requestType, resolver, promise)
}

private def resolveWithResolvers(
Expand All @@ -127,7 +137,7 @@ private[io] final class AsyncDnsResolver(
case Nil =>
Future.failed(ResolveFailedException(s"Failed to resolve $name with nameservers: $nameServers"))
case head :: tail =>
resolveWithSearch(name, requestType, head).recoverWith {
resolveWithSearch(name, requestType, head, settings, self).recoverWith {
case NonFatal(t) =>
t match {
case _: AskTimeoutException =>
Expand All @@ -136,90 +146,39 @@ private[io] final class AsyncDnsResolver(
log.info("Resolve of {} failed. Trying next name server {}", name, t.getMessage)
}
resolveWithResolvers(name, requestType, tail)
}
}(ExecutionContexts.parasitic)
}
}

private def sendQuestion(resolver: ActorRef, message: DnsQuestion): Future[Answer] = {
val result = (resolver ? message).mapTo[Answer]
result.failed.foreach { _ =>
resolver ! DropRequest(message.id)
}
result
}

private def resolveWithSearch(
private def resolvePromise(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand it correctly that the only reason for going back to the actor via Internal.Resolve is that the nextId() is needed here, or is there something else that requires this to be running in the actor?

If that is the case, would it be better to change the requestId counter to an a thread safe counter (AtomicInteger) and stay in the Future land in the companion object methods?

name: String,
requestType: RequestType,
resolver: ActorRef): Future[DnsProtocol.Resolved] = {
if (settings.SearchDomains.nonEmpty) {
val nameWithSearch = settings.SearchDomains.map(sd => name + "." + sd)
// ndots is a heuristic used to try and work out whether the name passed in is a fully qualified domain name,
// or a name relative to one of the search names. The idea is to prevent the cost of doing a lookup that is
// obviously not going to resolve. So, if a host has less than ndots dots in it, then we don't try and resolve it,
// instead, we go directly to the search domains, or at least that's what the man page for resolv.conf says. In
// practice, Linux appears to implement something slightly different, if the name being searched contains less
// than ndots dots, then it should be searched last, rather than first. This means if the heuristic wrongly
// identifies a domain as being relative to the search domains, it will still be looked up if it doesn't resolve
// at any of the search domains, albeit with the latency of having to have done all the searches first.
val toResolve = if (name.count(_ == '.') >= settings.NDots) {
name :: nameWithSearch
} else {
nameWithSearch :+ name
}
resolveFirst(toResolve, requestType, resolver)
} else {
resolve(name, requestType, resolver)
}
}

private def resolveFirst(
searchNames: List[String],
requestType: RequestType,
resolver: ActorRef): Future[DnsProtocol.Resolved] = {
searchNames match {
case searchName :: Nil =>
resolve(searchName, requestType, resolver)
case searchName :: remaining =>
resolve(searchName, requestType, resolver).flatMap { resolved =>
if (resolved.records.isEmpty) resolveFirst(remaining, requestType, resolver)
else Future.successful(resolved)
}
case Nil =>
// This can't happen
Future.failed(new IllegalStateException("Failed to 'resolveFirst': 'searchNames' must not be empty"))
}
}

private def resolve(name: String, requestType: RequestType, resolver: ActorRef): Future[DnsProtocol.Resolved] = {
resolver: ActorRef,
promise: Promise[DnsProtocol.Resolved]): Unit = {
log.debug("Attempting to resolve {} with {}", name, resolver)
val caseFoldedName = Helpers.toRootLowerCase(name)
requestType match {
case Ip(ipv4, ipv6) =>
val ipv4Recs: Future[Answer] =
if (ipv4)
sendQuestion(resolver, Question4(nextId(), caseFoldedName))
else
Empty
val ipv4Recs =
if (ipv4) sendQuestion(resolver, Question4(nextId(), caseFoldedName))
else Empty

val ipv6Recs =
if (ipv6)
sendQuestion(resolver, Question6(nextId(), caseFoldedName))
else
Empty
if (ipv6) sendQuestion(resolver, Question6(nextId(), caseFoldedName))
else Empty

for {
ipv4 <- ipv4Recs
ipv6 <- ipv6Recs
} yield DnsProtocol.Resolved(name, ipv4.rrs ++ ipv6.rrs, ipv4.additionalRecs ++ ipv6.additionalRecs)
promise.completeWith(ipv4Recs.flatMap { v4 =>
ipv6Recs.map { v6 =>
DnsProtocol.Resolved(name, v4.rrs ++ v6.rrs, v4.additionalRecs ++ v6.additionalRecs)
}(ExecutionContexts.parasitic)
}(ExecutionContexts.parasitic))
Copy link
Contributor Author

@leviramsey leviramsey Apr 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, personal rule to never implicit val ec = ExecutionContexts.parasitic...


case Srv =>
sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)).map(answer => {
promise.completeWith(sendQuestion(resolver, SrvQuestion(nextId(), caseFoldedName)).map { answer =>
DnsProtocol.Resolved(name, answer.rrs, answer.additionalRecs)
})
}(ExecutionContexts.parasitic))
}
}

}

/**
Expand Down Expand Up @@ -247,4 +206,75 @@ private[akka] object AsyncDnsResolver {
Future.successful(Answer(-1, immutable.Seq.empty[ResourceRecord], immutable.Seq.empty[ResourceRecord]))

case class ResolveFailedException(msg: String) extends Exception(msg)

// Internal commands to avoid calling private methods in future callbacks
private[akka] object Internal {
case class Resolve(
name: String,
requestType: RequestType,
resolver: ActorRef,
promise: Promise[DnsProtocol.Resolved])
}

// These methods are not in the class so that we can be sure they don't access actor state
private def resolveWithSearch(
name: String,
requestType: RequestType,
resolver: ActorRef,
settings: DnsSettings,
self: ActorRef): Future[DnsProtocol.Resolved] = {
if (settings.SearchDomains.nonEmpty) {
val nameWithSearch = settings.SearchDomains.map(sd => name + "." + sd)
// ndots is a heuristic used to try and work out whether the name passed in is a fully qualified domain name,
// or a name relative to one of the search names. The idea is to prevent the cost of doing a lookup that is
// obviously not going to resolve. So, if a host has less than ndots dots in it, then we don't try and resolve it,
// instead, we go directly to the search domains, or at least that's what the man page for resolv.conf says. In
// practice, Linux appears to implement something slightly different, if the name being searched contains less
// than ndots dots, then it should be searched last, rather than first. This means if the heuristic wrongly
// identifies a domain as being relative to the search domains, it will still be looked up if it doesn't resolve
// at any of the search domains, albeit with the latency of having to have done all the searches first.
val toResolve = if (name.count(_ == '.') >= settings.NDots) {
name :: nameWithSearch
} else {
nameWithSearch :+ name
}
resolveFirst(toResolve, requestType, resolver, self)
} else {
val selfCmd = Internal.Resolve(name, requestType, resolver, Promise())
self ! selfCmd
selfCmd.promise.future
}
}

private def resolveFirst(
searchNames: List[String],
requestType: RequestType,
resolver: ActorRef,
self: ActorRef): Future[DnsProtocol.Resolved] =
searchNames.headOption match {
case Some(searchName) =>
val selfCmd = Internal.Resolve(searchName, requestType, resolver, Promise())
self ! selfCmd

if (searchNames.tail.isEmpty) {
selfCmd.promise.future
} else {
selfCmd.promise.future.flatMap { resolved =>
if (resolved.records.isEmpty) resolveFirst(searchNames.tail, requestType, resolver, self)
else selfCmd.promise.future
}(ExecutionContexts.parasitic)
}

case None =>
// can't happen
Future.failed(new IllegalStateException("Failed to 'resolveFirst': 'searchNames' must not be empty"))
}

private def sendQuestion(resolver: ActorRef, message: DnsQuestion)(implicit timeout: Timeout): Future[Answer] = {
val result = (resolver ? message).mapTo[Answer]
result.failed.foreach { _ =>
resolver ! DropRequest(message.id)
}(ExecutionContexts.parasitic)
result
}
}