Skip to content

Commit

Permalink
Support type selectors in Resolve (#2997)
Browse files Browse the repository at this point in the history
This PR add support for new selector pattern `_:Type`.

In addition to `_` and `__`, which select arbitrary segments, the
`_:MyType` and `__:MyType` patterns can select modules of the specified
type.

The type is matched by it's name and optionally by it's enclosing types
and packages, separated by a `.` sign. Since this is also used to
separate target path segments, a type selector segment containing a `.`
needs to be enclosed in parenthesis. A full qualified type can be
enforced with the `_root_` package.

Example: Find all test jars

```sh
> mill resolve __:TestModule.jar
> mill resolve "(__:scalalib.TestModule).jar"
> mill resolve "(__:mill.scalalib.TestModule).jar"
> mill resolve "(__:_root_.mill.scalalib.TestModule).jar"
```

If a `^` or `!` is preceding the type pattern, it only matches segments
not an instance of that specified type. Please note that in some shells
like `bash`, you need to mask the `!` character.

Example: Find all jars not in test modules

```sh
> mill resolve __:^TestModule.jar
```

You can also provide more than one type pattern, separated with `:`. 

Example: Find all `JavaModule`s which are not `ScalaModule`s or
`TestModule`s:

```sh
> mill resolve "__:JavaModule:^ScalaModule:^TestModule.jar"
```

Remarks:

* Kudos to @lihaoyi who refactored the resolver in
#2511 and made this PR possible.
I tried to implement it multiple times before, and ever got bitten by
the old gnarly resolver code.
* It's currently not possible to match task/target types. It might be
possible, but due to `Task` being a parametrized type, it might not be
as easy to implement and use.

Fix #1550

Pull request: #2997
  • Loading branch information
lefou committed Feb 4, 2024
1 parent 5e71e4d commit 8d315d7
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 21 deletions.
19 changes: 16 additions & 3 deletions main/resolve/src/mill/resolve/ParseArgs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,23 @@ object ParseArgs {
}

private def selector[_p: P]: P[(Option[Segments], Segments)] = {
def ident2 = P(CharsWhileIn("a-zA-Z0-9_\\-.")).!
def segment = P(mill.define.Reflect.ident).map(Segment.Label)
def crossSegment = P("[" ~ ident2.rep(1, sep = ",") ~ "]").map(Segment.Cross)
def wildcard = P("__" | "_")
def label = mill.define.Reflect.ident

def typeQualifier(simple: Boolean) = {
val maxSegments = if (simple) 0 else Int.MaxValue
P(("^" | "!").? ~~ label ~~ ("." ~~ label).rep(max = maxSegments)).!
}

def typePattern(simple: Boolean) = P(wildcard ~~ (":" ~~ typeQualifier(simple)).rep(1)).!

def segment0(simple: Boolean) = P(typePattern(simple) | label).map(Segment.Label)
def segment = P("(" ~ segment0(false) ~ ")" | segment0(true))

def identCross = P(CharsWhileIn("a-zA-Z0-9_\\-.")).!
def crossSegment = P("[" ~ identCross.rep(1, sep = ",") ~ "]").map(Segment.Cross)
def defaultCrossSegment = P("[]").map(_ => Segment.Cross(Seq()))

def simpleQuery = P(segment ~ ("." ~ segment | crossSegment | defaultCrossSegment).rep).map {
case (h, rest) => Segments(h +: rest)
}
Expand Down
27 changes: 21 additions & 6 deletions main/resolve/src/mill/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object Resolve {
Right(resolved.map(_.segments))
}

private[mill] override def deduplicate(items: List[Segments]) = items.distinct
private[mill] override def deduplicate(items: List[Segments]): List[Segments] = items.distinct
}

object Tasks extends Resolve[NamedTask[Any]] {
Expand Down Expand Up @@ -83,11 +83,11 @@ object Resolve {
)
}

private[mill] override def deduplicate(items: List[NamedTask[Any]]) =
private[mill] override def deduplicate(items: List[NamedTask[Any]]): List[NamedTask[Any]] =
items.distinctBy(_.ctx.segments)
}

private def instantiateTarget(r: Resolved.Target, p: Module) = {
private def instantiateTarget(r: Resolved.Target, p: Module): Either[String, Target[_]] = {
val definition = Reflect
.reflect(p.getClass, classOf[Target[_]], _ == r.segments.parts.last, true)
.head
Expand Down Expand Up @@ -230,11 +230,23 @@ trait Resolve[T] {
): Either[String, Seq[T]] = {
val rootResolved = ResolveCore.Resolved.Module(Segments(), rootModule.getClass)
val resolved =
ResolveCore.resolve(rootModule, sel.value.toList, rootResolved, Segments()) match {
ResolveCore.resolve(
rootModule = rootModule,
remainingQuery = sel.value.toList,
current = rootResolved,
querySoFar = Segments()
) match {
case ResolveCore.Success(value) => Right(value)
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
val allPossibleNames = rootModule.millDiscover.value.values.flatMap(_._1).toSet
Left(ResolveNotFoundHandler(sel, segments, found, next, possibleNexts, allPossibleNames))
Left(ResolveNotFoundHandler(
selector = sel,
segments = segments,
found = found,
next = next,
possibleNexts = possibleNexts,
allPossibleNames = allPossibleNames
))
case ResolveCore.Error(value) => Left(value)
}

Expand All @@ -245,7 +257,10 @@ trait Resolve[T] {

private[mill] def deduplicate(items: List[T]): List[T] = items

private[mill] def resolveRootModule(rootModule: BaseModule, scopedSel: Option[Segments]) = {
private[mill] def resolveRootModule(
rootModule: BaseModule,
scopedSel: Option[Segments]
): Either[String, BaseModule] = {
scopedSel match {
case None => Right(rootModule)
case Some(scoping) =>
Expand Down
124 changes: 112 additions & 12 deletions main/resolve/src/mill/resolve/ResolveCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,37 @@ private object ResolveCore {
case "__" =>
val self = Seq(Resolved.Module(m.segments, m.cls))
val transitiveOrErr =
resolveTransitiveChildren(rootModule, m.cls, None, current.segments)
resolveTransitiveChildren(rootModule, m.cls, None, current.segments, Nil)

transitiveOrErr.map(transitive => self ++ transitive)

case "_" =>
resolveDirectChildren(rootModule, m.cls, None, current.segments)

case pattern if pattern.startsWith("__:") =>
val typePattern = pattern.split(":").drop(1)
val self = Seq(Resolved.Module(m.segments, m.cls))

val transitiveOrErr = resolveTransitiveChildren(
rootModule,
m.cls,
None,
current.segments,
typePattern
)

transitiveOrErr.map(transitive => self ++ transitive)

case pattern if pattern.startsWith("_:") =>
val typePattern = pattern.split(":").drop(1)
resolveDirectChildren(
rootModule,
m.cls,
None,
current.segments,
typePattern
)

case _ =>
resolveDirectChildren(rootModule, m.cls, Some(singleLabel), current.segments)
}
Expand Down Expand Up @@ -186,22 +210,75 @@ private object ResolveCore {
cls: Class[_],
nameOpt: Option[String],
segments: Segments
): Either[String, Set[Resolved]] =
resolveTransitiveChildren(rootModule, cls, nameOpt, segments, Nil)

def resolveTransitiveChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments,
typePattern: Seq[String]
): Either[String, Set[Resolved]] = {
for {
direct <- resolveDirectChildren(rootModule, cls, nameOpt, segments)
indirect0 = direct
.collect { case m: Resolved.Module =>
resolveTransitiveChildren(rootModule, m.cls, nameOpt, m.segments)
}
indirect <- EitherOps.sequence(indirect0).map(_.flatten)
} yield direct ++ indirect
val direct = resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern)
direct.flatMap { direct =>
for {
directTraverse <- resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil)
indirect0 = directTraverse
.collect { case m: Resolved.Module =>
resolveTransitiveChildren(rootModule, m.cls, nameOpt, m.segments, typePattern)
}
indirect <- EitherOps.sequence(indirect0).map(_.flatten)
} yield direct ++ indirect
}
}

private def resolveParents(c: Class[_]): Seq[Class[_]] =
Seq(c) ++
Option(c.getSuperclass).toSeq.flatMap(resolveParents) ++
c.getInterfaces.flatMap(resolveParents)

/**
* Check if the given class matches a given type selector as string
* @param cls
* @param typePattern
* @return
*/
private def classMatchesTypePred(typePattern: Seq[String])(cls: Class[_]): Boolean =
typePattern
.forall { pat =>
val negate = pat.startsWith("^") || pat.startsWith("!")
val clsPat = pat.drop(if (negate) 1 else 0)

// We split full class names by `.` and `$`
// a class matches a type patter, if the type pattern segments match from the right
// to express a full match, use `_root_` as first segment

val typeNames = clsPat.split("[.$]").toSeq.reverse

val parents = resolveParents(cls)
val classNames = parents.flatMap(c =>
("_root_$" + c.getName).split("[.$]").toSeq.reverse.inits.toSeq.filter(_.nonEmpty)
)

val isOfType = classNames.contains(typeNames)
if (negate) !isOfType else isOfType
}

def resolveDirectChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments
): Either[String, Set[Resolved]] =
resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern = Nil)

def resolveDirectChildren(
rootModule: Module,
cls: Class[_],
nameOpt: Option[String],
segments: Segments,
typePattern: Seq[String]
): Either[String, Set[Resolved]] = {

val crossesOrErr = if (classOf[Cross[_]].isAssignableFrom(cls) && nameOpt.isEmpty) {
Expand All @@ -216,15 +293,19 @@ private object ResolveCore {
} else Right(Nil)

crossesOrErr.flatMap { crosses =>
resolveDirectChildren0(rootModule, segments, cls, nameOpt)
val filteredCrosses = crosses.filter { c =>
classMatchesTypePred(typePattern)(c.cls)
}

resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern)
.map(
_.map {
case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls)
case (Resolved.Target(s), _) => Resolved.Target(segments ++ s)
case (Resolved.Command(s), _) => Resolved.Command(segments ++ s)
}
.toSet
.++(crosses)
.++(filteredCrosses)
)
}
}
Expand All @@ -234,15 +315,25 @@ private object ResolveCore {
segments: Segments,
cls: Class[_],
nameOpt: Option[String]
): Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] =
resolveDirectChildren0(rootModule, segments, cls, nameOpt, Nil)

def resolveDirectChildren0(
rootModule: Module,
segments: Segments,
cls: Class[_],
nameOpt: Option[String],
typePattern: Seq[String]
): Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] = {
def namePred(n: String) = nameOpt.isEmpty || nameOpt.contains(n)

val modulesOrErr: Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] =
val modulesOrErr: Either[String, Seq[(Resolved, Option[Module => Either[String, Module]])]] = {
if (classOf[DynamicModule].isAssignableFrom(cls)) {
instantiateModule(rootModule, segments).map {
case m: DynamicModule =>
m.millModuleDirectChildren
.filter(c => namePred(c.millModuleSegments.parts.last))
.filter(c => classMatchesTypePred(typePattern)(c.getClass))
.map(c =>
(
Resolved.Module(
Expand All @@ -256,6 +347,14 @@ private object ResolveCore {
} else Right {
Reflect
.reflectNestedObjects0[Module](cls, namePred)
.filter {
case (_, member) =>
val memberCls = member match {
case f: java.lang.reflect.Field => f.getType
case f: java.lang.reflect.Method => f.getReturnType
}
classMatchesTypePred(typePattern)(memberCls)
}
.map { case (name, member) =>
Resolved.Module(
Segments.labels(decode(name)),
Expand All @@ -274,6 +373,7 @@ private object ResolveCore {
)
}
}
}

val targets = Reflect
.reflect(cls, classOf[Target[_]], namePred, noParams = true)
Expand Down
Loading

0 comments on commit 8d315d7

Please sign in to comment.