From 26ae121a035c6b3d3e85f34c318bce41bc2175f3 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Tue, 17 Dec 2024 11:16:26 -0500 Subject: [PATCH] ir gen mvp --- hail/build.mill | 17 +++ hail/modules/ir-gen/src/Main.scala | 135 ++++++++++++++++++ .../src/main/scala/is/hail/expr/ir/IR.scala | 71 +++++---- 3 files changed, 187 insertions(+), 36 deletions(-) create mode 100644 hail/modules/ir-gen/src/Main.scala diff --git a/hail/build.mill b/hail/build.mill index b8867ed29a6f..5bb5e79edfb6 100644 --- a/hail/build.mill +++ b/hail/build.mill @@ -179,6 +179,10 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer buildInfo(), ) + override def generatedSources: T[Seq[PathRef]] = Task { + Seq(`ir-gen`.generate()) + } + override def unmanagedClasspath: T[Agg[PathRef]] = Agg(shadedazure.assembly()) @@ -250,6 +254,19 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer PathRef(T.dest) } + object `ir-gen` extends HailScalaModule { + def ivyDeps = Agg( + ivy"com.lihaoyi::mainargs:0.6.2", + ivy"com.lihaoyi::os-lib:0.10.7", + ivy"com.lihaoyi::sourcecode:0.4.2", + ) + + def generate: T[PathRef] = Task { + runForkedTask(finalMainClass, Task.Anon { Args("--path", T.dest) })() + PathRef(T.dest) + } + } + object memory extends JavaModule { // with CrossValue { override def zincIncrementalCompilation: T[Boolean] = false diff --git a/hail/modules/ir-gen/src/Main.scala b/hail/modules/ir-gen/src/Main.scala new file mode 100644 index 000000000000..a38918391981 --- /dev/null +++ b/hail/modules/ir-gen/src/Main.scala @@ -0,0 +1,135 @@ +import mainargs.{ParserForMethods, main} + +sealed abstract class Trait(val name: String) + +object Trivial extends Trait("TrivialIR") + +case class NChildren(static: Int = 0, dynamic: String = "") { + def +(other: NChildren): NChildren = NChildren( + static = static + other.static, + dynamic = if (dynamic.isEmpty) other.dynamic else s"$dynamic + ${other.dynamic}" + ) +} + +sealed abstract class AttOrChild { + val name: String + def generateDeclaration: String + def constraints: Seq[String] = Seq.empty + def nChildren: NChildren = NChildren() +} + +final case class Att(name: String, typ: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: $typ" +} + +final case class Child(name: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: IR" + override def nChildren: NChildren = NChildren(static = 1) +} + +final case class ChildPlus(name: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: Seq[IR]" + override def constraints: Seq[String] = Seq(s"$name.nonEmpty") + override def nChildren: NChildren = NChildren(dynamic = "name.size") +} + +final case class ChildStar(name: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: Seq[IR]" + override def nChildren: NChildren = NChildren(dynamic = "name.size") +} + +case class IR( + name: String, + attsAndChildren: Seq[AttOrChild], + traits: Seq[Trait] = Seq.empty, + extraMethods: Seq[String] = Seq.empty, + applyMethods: Seq[String] = Seq.empty, + docstring: String = "", +) { + def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits) + def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef) + def withApply(methodDef: String): IR = copy(applyMethods = applyMethods :+ methodDef) + def withDocstring(docstring: String): IR = copy(docstring = docstring) + + private def nChildren: NChildren = attsAndChildren.foldLeft(NChildren())(_ + _.nChildren) + + private def paramList = s"$name(${attsAndChildren.map(_.generateDeclaration).mkString(", ")})" + + private def classDecl = + s"final case class $paramList extends IR" + traits.map(" with " + _.name).mkString + + private def classBody = { + val constraints = attsAndChildren.flatMap(_.constraints) + if (constraints.nonEmpty || extraMethods.nonEmpty) { + ( + " {" + + (if (constraints.nonEmpty) + constraints.map(c => s" require($c)").mkString("\n", "\n", "\n") + else "") + + ( + if (extraMethods.nonEmpty) + extraMethods.map(" " + _).mkString("\n", "\n", "\n") + else "" + ) + + "}" + ) + } else "" + } + + private def classDef = + (if (docstring.nonEmpty) s"\n// $docstring\n" else "") + classDecl + classBody + + private def companionBody = applyMethods.map(" " + _).mkString("\n") + + private def companionDef = + if (companionBody.isEmpty) "" else s"object $name {\n$companionBody\n}\n" + + def generateDef: String = companionDef + classDef + "\n" +} + +object Main { + def node(name: String, attsAndChildren: AttOrChild*): IR = IR(name, attsAndChildren) + + def allNodes: Seq[IR] = { + val r = Seq.newBuilder[IR] + + r += node("I32", Att("x", "Int")).withTraits(Trivial) + r += node("I64", Att("x", "Long")).withTraits(Trivial) + r += node("F32", Att("x", "Float")).withTraits(Trivial) + r += node("F64", Att("x", "Double")).withTraits(Trivial) + r += node("Str", Att("x", "String")).withTraits(Trivial) + .withMethod( + "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\"" + ) + r += node("True").withTraits(Trivial) + r += node("False").withTraits(Trivial) + r += node("Void").withTraits(Trivial) + r += node("NA", Att("_typ", "Type")).withTraits(Trivial) + r += node("UUID4", Att("id", "String")) + .withDocstring( + "WARNING! This node can only be used when trying to append a one-off, " + + "random string that will not be reused elsewhere in the pipeline. " + + "Any other uses will need to write and then read again; this node is non-deterministic " + + "and will not e.g. exhibit the correct semantics when self-joining on streams." + ) + .withApply("def apply(): UUID4 = UUID4(genUID())") + r += node("Cast", Child("v"), Att("_typ", "Type")) + r += node("CastRename", Child("v"), Att("_typ", "Type")) + r += node("IsNA", Child("value")) + r += node("Coalesce", ChildPlus("values")) + r += node("Consume", Child("value")) + r += node("If", Child("cond"), Child("cnsq"), Child("altr")) + r += node("Switch", Child("x"), Child("default"), ChildStar("cases")) + .withMethod("override lazy val size: Int = 2 + cases.length") + + r.result() + } + + @main + def main(path: String) = { + val gen = "package is.hail.expr.ir\n\n" + allNodes.map(_.generateDef).mkString("\n") + os.write(os.Path(path) / "IR_gen.scala", gen) + } + + def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) +} diff --git a/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala b/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala index 5f432ab2dd5c..9307b67990ac 100644 --- a/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala +++ b/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala @@ -27,7 +27,7 @@ import java.io.OutputStream import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import org.json4s.JsonAST.{JNothing, JString} -sealed trait IR extends BaseIR { +trait IR extends BaseIR { private var _typ: Type = null def typ: Type = { @@ -72,12 +72,12 @@ sealed trait IR extends BaseIR { def unwrap: IR = _unwrap(this) } -sealed trait TypedIR[T <: Type] extends IR { +trait TypedIR[T <: Type] extends IR { override def typ: T = tcoerce[T](super.typ) } // Mark Refs and constants as IRs that are safe to duplicate -sealed trait TrivialIR extends IR +trait TrivialIR extends IR object Literal { def coerce(t: Type, x: Any): IR = { @@ -146,48 +146,47 @@ class WrappedByteArrays(val ba: Array[Array[Byte]]) { } } -final case class I32(x: Int) extends IR with TrivialIR -final case class I64(x: Long) extends IR with TrivialIR -final case class F32(x: Float) extends IR with TrivialIR -final case class F64(x: Double) extends IR with TrivialIR +//final case class I32(x: Int) extends IR with TrivialIR +//final case class I64(x: Long) extends IR with TrivialIR +//final case class F32(x: Float) extends IR with TrivialIR +//final case class F64(x: Double) extends IR with TrivialIR -final case class Str(x: String) extends IR with TrivialIR { - override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" -} - -final case class True() extends IR with TrivialIR -final case class False() extends IR with TrivialIR -final case class Void() extends IR with TrivialIR +//final case class Str(x: String) extends IR with TrivialIR { +// override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" +//} -object UUID4 { - def apply(): UUID4 = UUID4(genUID()) -} +//final case class True() extends IR with TrivialIR +//final case class False() extends IR with TrivialIR +//final case class Void() extends IR with TrivialIR -// WARNING! This node can only be used when trying to append a one-off, -// random string that will not be reused elsewhere in the pipeline. -// Any other uses will need to write and then read again; this node is -// non-deterministic and will not e.g. exhibit the correct semantics when -// self-joining on streams. -final case class UUID4(id: String) extends IR +//object UUID4 { +// def apply(): UUID4 = UUID4(genUID()) +//} +// +//// WARNING! This node can only be used when trying to append a one-off, +//// random string that will not be reused elsewhere in the pipeline. +//// Any other uses will need to write and then read again; this node is +//// non-deterministic and will not e.g. exhibit the correct semantics when +//// self-joining on streams. +//final case class UUID4(id: String) extends IR -final case class Cast(v: IR, _typ: Type) extends IR -final case class CastRename(v: IR, _typ: Type) extends IR +//final case class Cast(v: IR, _typ: Type) extends IR +//final case class CastRename(v: IR, _typ: Type) extends IR -final case class NA(_typ: Type) extends IR with TrivialIR -final case class IsNA(value: IR) extends IR +//final case class NA(_typ: Type) extends IR with TrivialIR +//final case class IsNA(value: IR) extends IR -final case class Coalesce(values: Seq[IR]) extends IR { - require(values.nonEmpty) -} +//final case class Coalesce(values: Seq[IR]) extends IR { +// require(values.nonEmpty) +//} -final case class Consume(value: IR) extends IR +//final case class Consume(value: IR) extends IR -final case class If(cond: IR, cnsq: IR, altr: IR) extends IR +//final case class If(cond: IR, cnsq: IR, altr: IR) extends IR -final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR { - override lazy val size: Int = - 2 + cases.length -} +//final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR { +// override lazy val size: Int = 2 + cases.length +//} object AggLet { def apply(name: Name, value: IR, body: IR, isScan: Boolean): IR = {