From 972d3344b62ea2e15f089d8fc641073ec803ec8e Mon Sep 17 00:00:00 2001
From: patrick-schultz <pschultz@broadinstitute.org>
Date: Tue, 17 Dec 2024 11:16:26 -0500
Subject: [PATCH] ir gen mvp

---
 hail/build.mill                        |  17 +++
 hail/hail/ir-gen/src/Main.scala        | 154 +++++++++++++++++++++++++
 hail/hail/src/is/hail/expr/ir/IR.scala |  88 +++++++-------
 3 files changed, 213 insertions(+), 46 deletions(-)
 create mode 100644 hail/hail/ir-gen/src/Main.scala

diff --git a/hail/build.mill b/hail/build.mill
index 6a14da4d4b3..44cc74e67b3 100644
--- a/hail/build.mill
+++ b/hail/build.mill
@@ -175,6 +175,10 @@ object hail extends HailModule { outer =>
     buildInfo(),
   )
 
+  override def generatedSources: T[Seq[PathRef]] = Task {
+    super.generatedSources() :+ `ir-gen`.generate() // keep inherited sources, add generated IR
+  }
+
   override def unmanagedClasspath: T[Agg[PathRef]] =
     Agg(shadedazure.assembly())
 
@@ -246,6 +250,19 @@ object hail extends HailModule { outer =>
     PathRef(T.dest)
   }
 
+  object `ir-gen` extends HailModule { // build module for the IR source generator
+    def ivyDeps = Agg(
+      ivy"com.lihaoyi::mainargs:0.6.2",
+      ivy"com.lihaoyi::os-lib:0.10.7",
+      ivy"com.lihaoyi::sourcecode:0.4.2",
+    )
+
+    def generate: T[PathRef] = Task { // runs this module's main class with `--path <T.dest>`
+      runForkedTask(finalMainClass, Task.Anon { Args("--path", T.dest) })()
+      PathRef(T.dest) // the task result is the destination directory itself
+    }
+  }
+
   object memory extends JavaModule { // with CrossValue {
     override def zincIncrementalCompilation: T[Boolean] = false
 
diff --git a/hail/hail/ir-gen/src/Main.scala b/hail/hail/ir-gen/src/Main.scala
new file mode 100644
index 00000000000..c85e162f8f8
--- /dev/null
+++ b/hail/hail/ir-gen/src/Main.scala
@@ -0,0 +1,154 @@
+import mainargs.{ParserForMethods, main}
+
+sealed abstract class Trait(val name: String) // `name` is emitted after `with` in class headers
+
+object Trivial extends Trait("TrivialIR")
+object BaseRef extends Trait("BaseRef")
+
+/** Child count: a fixed `static` part plus a `dynamic` expression string (may be empty). */
+case class NChildren(static: Int = 0, dynamic: String = "") {
+  // Joins only the non-empty dynamic parts, so no dangling " + " is produced.
+  def +(other: NChildren): NChildren =
+    NChildren(static + other.static, Seq(dynamic, other.dynamic).filter(_.nonEmpty).mkString(" + "))
+}
+
+sealed abstract class AttOrChild { // one constructor parameter of a generated node class
+  val name: String
+  def generateDeclaration: String // parameter declaration text, e.g. `x: Int`
+  def constraints: Seq[String] = Seq.empty // each entry is wrapped in `require(...)`
+  def nChildren: NChildren = NChildren() // contributes no children unless overridden
+}
+
+final case class Att(name: String, typ: String, isVar: Boolean = false) extends AttOrChild {
+  override def generateDeclaration: String = (if (isVar) s"var $name" else name) + s": $typ"
+}
+
+final case class Child(name: String) extends AttOrChild { // exactly one child of type IR
+  override def generateDeclaration: String = s"$name: IR"
+  override def nChildren: NChildren = NChildren(static = 1)
+}
+
+final case class ChildPlus(name: String) extends AttOrChild { // a non-empty sequence of children
+  override def generateDeclaration: String = s"$name: IndexedSeq[IR]"
+  override def constraints: Seq[String] = Seq(s"$name.nonEmpty")
+  override def nChildren: NChildren = NChildren(dynamic = s"$name.size") // size of this field
+}
+
+final case class ChildStar(name: String) extends AttOrChild { // a possibly empty child sequence
+  override def generateDeclaration: String = s"$name: IndexedSeq[IR]"
+  override def nChildren: NChildren = NChildren(dynamic = s"$name.size") // size of this field
+}
+
+// Description of one IR node class to generate: its constructor parameters (attributes and
+// children, in order), mixed-in traits, extra members, companion `apply` methods and a comment.
+case class IR(
+  name: String,
+  attsAndChildren: Seq[AttOrChild],
+  traits: Seq[Trait] = Seq.empty,
+  extraMethods: Seq[String] = Seq.empty,
+  applyMethods: Seq[String] = Seq.empty,
+  docstring: String = "",
+) {
+  // Builder-style helpers; each returns an updated copy and leaves this description untouched.
+  def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits)
+  def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef)
+  def withApply(methodDef: String): IR = copy(applyMethods = applyMethods :+ methodDef)
+  def withDocstring(docstring: String): IR = copy(docstring = docstring)
+
+  // Combined child count of all parameters; not referenced by the emitted text in this class.
+  private def nChildren: NChildren = attsAndChildren.foldLeft(NChildren())(_ + _.nChildren)
+
+  // Source text of the generated `childrenSeq`: single children are wrapped in `FastSeq`,
+  // sequence-valued children are used as is, and the pieces are concatenated in parameter order.
+  private def children: String = {
+    val tmp = attsAndChildren.flatMap {
+      case _: Att => None
+      case c: Child => Some(s"FastSeq(${c.name})")
+      case cs: ChildPlus => Some(cs.name)
+      case cs: ChildStar => Some(cs.name)
+    }
+    if (tmp.isEmpty) "FastSeq.empty" else tmp.mkString(" ++ ")
+  }
+
+  private def paramList = s"$name(${attsAndChildren.map(_.generateDeclaration).mkString(", ")})"
+
+  private def classDecl =
+    s"final case class $paramList extends IR" + traits.map(" with " + _.name).mkString
+
+  // Class body: the `require` checks (if any) followed by the extra methods and `childrenSeq`.
+  // `childrenSeq` is always emitted, so the body is never empty.
+  private def classBody = {
+    val requires = attsAndChildren.flatMap(_.constraints).map(c => s"  require($c)")
+    val members =
+      (extraMethods :+ s"override lazy val childrenSeq: IndexedSeq[IR] = $children").map("  " + _)
+    val requireSection = if (requires.isEmpty) "" else requires.mkString("\n", "\n", "\n")
+    " {" + requireSection + members.mkString("\n", "\n", "\n") + "}"
+  }
+
+  // Prefix every docstring line with `//` so multi-line text stays inside the comment.
+  private def docComment =
+    if (docstring.isEmpty) "" else docstring.split("\n").map("// " + _).mkString("\n", "\n", "\n")
+
+  private def classDef = docComment + classDecl + classBody
+
+  private def companionBody = applyMethods.map("  " + _).mkString("\n")
+
+  private def companionDef =
+    if (companionBody.isEmpty) "" else s"object $name {\n$companionBody\n}\n"
+
+  // Full generated source for this node: optional companion object, then the case class.
+  def generateDef: String = companionDef + classDef + "\n"
+}
+
+// Code generator entry point: describes the IR nodes and writes their generated source file.
+object Main {
+  def node(name: String, attsAndChildren: AttOrChild*): IR = IR(name, attsAndChildren)
+
+  // Descriptions of all generated nodes, in the order they are written to the output file.
+  def allNodes: Seq[IR] = {
+    val r = Seq.newBuilder[IR]
+
+    r += node("I32", Att("x", "Int")).withTraits(Trivial)
+    r += node("I64", Att("x", "Long")).withTraits(Trivial)
+    r += node("F32", Att("x", "Float")).withTraits(Trivial)
+    r += node("F64", Att("x", "Double")).withTraits(Trivial)
+    r += node("Str", Att("x", "String")).withTraits(Trivial)
+      .withMethod(
+        "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\""
+      )
+    r += node("True").withTraits(Trivial)
+    r += node("False").withTraits(Trivial)
+    r += node("Void").withTraits(Trivial)
+    r += node("NA", Att("_typ", "Type")).withTraits(Trivial)
+    r += node("UUID4", Att("id", "String"))
+      .withDocstring(
+        "WARNING! This node can only be used when trying to append a one-off, "
+          + "random string that will not be reused elsewhere in the pipeline. "
+          + "Any other uses will need to write and then read again; this node is non-deterministic "
+          + "and will not e.g. exhibit the correct semantics when self-joining on streams."
+      )
+      .withApply("def apply(): UUID4 = UUID4(genUID())")
+    r += node("Cast", Child("v"), Att("_typ", "Type"))
+    r += node("CastRename", Child("v"), Att("_typ", "Type"))
+    r += node("IsNA", Child("value"))
+    r += node("Coalesce", ChildPlus("values"))
+    r += node("Consume", Child("value"))
+    r += node("If", Child("cond"), Child("cnsq"), Child("altr"))
+    r += node("Switch", Child("x"), Child("default"), ChildStar("cases"))
+      .withMethod("override lazy val size: Int = 2 + cases.length")
+    r += node("Ref", Att("name", "Name"), Att("_typ", "Type", isVar = true)).withTraits(BaseRef)
+    r.result()
+  }
+
+  @main
+  def main(path: String): Unit = {
+    val pack = "package is.hail.expr.ir"
+    val imports = Seq("is.hail.types.virtual.Type", "is.hail.utils.{FastSeq, StringEscapeUtils}")
+    val importLines = imports.map(i => s"import $i").mkString("\n")
+    val gen = pack + "\n\n" + importLines + "\n\n" + allNodes.map(_.generateDef).mkString("\n")
+    // Overwrite stale output and create the directory; relative paths resolve against os.pwd.
+    os.write.over(os.Path(path, os.pwd) / "IR_gen.scala", gen, createFolders = true)
+  }
+
+  def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) // JVM entry point
+}
diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala
index 5f432ab2dd5..65417b45a70 100644
--- a/hail/hail/src/is/hail/expr/ir/IR.scala
+++ b/hail/hail/src/is/hail/expr/ir/IR.scala
@@ -27,20 +27,22 @@ import java.io.OutputStream
 import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints}
 import org.json4s.JsonAST.{JNothing, JString}
 
-sealed trait IR extends BaseIR {
+trait IR extends BaseIR {
   private var _typ: Type = null
 
-  def typ: Type = {
-    if (_typ == null)
+  override def typ: Type = {
+    if (_typ == null) {
       try
         _typ = InferType(this)
       catch {
         case e: Throwable => throw new RuntimeException(s"typ: inference failure:", e)
       }
+      assert(_typ != null)
+    }
     _typ
   }
 
-  protected lazy val childrenSeq: IndexedSeq[BaseIR] =
+  override protected lazy val childrenSeq: IndexedSeq[BaseIR] =
     Children(this)
 
   override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): IR =
@@ -72,12 +74,12 @@ sealed trait IR extends BaseIR {
   def unwrap: IR = _unwrap(this)
 }
 
-sealed trait TypedIR[T <: Type] extends IR {
+trait TypedIR[T <: Type] extends IR {
   override def typ: T = tcoerce[T](super.typ)
 }
 
 // Mark Refs and constants as IRs that are safe to duplicate
-sealed trait TrivialIR extends IR
+trait TrivialIR extends IR
 
 object Literal {
   def coerce(t: Type, x: Any): IR = {
@@ -146,48 +148,47 @@ class WrappedByteArrays(val ba: Array[Array[Byte]]) {
   }
 }
 
-final case class I32(x: Int) extends IR with TrivialIR
-final case class I64(x: Long) extends IR with TrivialIR
-final case class F32(x: Float) extends IR with TrivialIR
-final case class F64(x: Double) extends IR with TrivialIR
+//final case class I32(x: Int) extends IR with TrivialIR
+//final case class I64(x: Long) extends IR with TrivialIR
+//final case class F32(x: Float) extends IR with TrivialIR
+//final case class F64(x: Double) extends IR with TrivialIR
 
-final case class Str(x: String) extends IR with TrivialIR {
-  override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
-}
+//final case class Str(x: String) extends IR with TrivialIR {
+//  override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
+//}
 
-final case class True() extends IR with TrivialIR
-final case class False() extends IR with TrivialIR
-final case class Void() extends IR with TrivialIR
+//final case class True() extends IR with TrivialIR
+//final case class False() extends IR with TrivialIR
+//final case class Void() extends IR with TrivialIR
 
-object UUID4 {
-  def apply(): UUID4 = UUID4(genUID())
-}
+//object UUID4 {
+//  def apply(): UUID4 = UUID4(genUID())
+//}
+//
+//// WARNING! This node can only be used when trying to append a one-off,
+//// random string that will not be reused elsewhere in the pipeline.
+//// Any other uses will need to write and then read again; this node is
+//// non-deterministic and will not e.g. exhibit the correct semantics when
+//// self-joining on streams.
+//final case class UUID4(id: String) extends IR
 
-// WARNING! This node can only be used when trying to append a one-off,
-// random string that will not be reused elsewhere in the pipeline.
-// Any other uses will need to write and then read again; this node is
-// non-deterministic and will not e.g. exhibit the correct semantics when
-// self-joining on streams.
-final case class UUID4(id: String) extends IR
+//final case class Cast(v: IR, _typ: Type) extends IR
+//final case class CastRename(v: IR, _typ: Type) extends IR
 
-final case class Cast(v: IR, _typ: Type) extends IR
-final case class CastRename(v: IR, _typ: Type) extends IR
+//final case class NA(_typ: Type) extends IR with TrivialIR
+//final case class IsNA(value: IR) extends IR
 
-final case class NA(_typ: Type) extends IR with TrivialIR
-final case class IsNA(value: IR) extends IR
+//final case class Coalesce(values: Seq[IR]) extends IR {
+//  require(values.nonEmpty)
+//}
 
-final case class Coalesce(values: Seq[IR]) extends IR {
-  require(values.nonEmpty)
-}
+//final case class Consume(value: IR) extends IR
 
-final case class Consume(value: IR) extends IR
+//final case class If(cond: IR, cnsq: IR, altr: IR) extends IR
 
-final case class If(cond: IR, cnsq: IR, altr: IR) extends IR
-
-final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR {
-  override lazy val size: Int =
-    2 + cases.length
-}
+//final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR {
+//  override lazy val size: Int = 2 + cases.length
+//}
 
 object AggLet {
   def apply(name: Name, value: IR, body: IR, isScan: Boolean): IR = {
@@ -238,17 +239,12 @@ object Block {
   }
 }
 
-sealed abstract class BaseRef extends IR with TrivialIR {
+trait BaseRef extends IR with TrivialIR {
   def name: Name
   def _typ: Type
 }
 
-final case class Ref(name: Name, var _typ: Type) extends BaseRef {
-  override def typ: Type = {
-    assert(_typ != null)
-    _typ
-  }
-}
+//final case class Ref(name: Name, var _typ: Type) extends BaseRef
 
 // Recur can't exist outside of loop
 // Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context