Skip to content

Commit

Permalink
Allow to beta reduce curried function applications in quotes reflect
Browse files Browse the repository at this point in the history
Previously, the curried functions with multiple applications
were not able to be beta-reduced in any way, which was unexpected.
Now we allow reducing any number of top-level function applications
for a curried function. This was also made clearer in the documentation
for the affected (Expr.betaReduce and Term.betaReduce) methods.
  • Loading branch information
jchyb committed Mar 26, 2024
1 parent a6c40b1 commit 8439370
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 21 deletions.
27 changes: 16 additions & 11 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -396,17 +396,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end TermTypeTest

object Term extends TermModule:
def betaReduce(tree: Term): Option[Term] =
tree match
case tpd.Block(Nil, expr) =>
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
else Some(tree1.withSpan(tree.span))

def betaReduce(tree: Term): Option[Term] =
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
override def transform(tree: Tree)(using Context): Tree = tree match {
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
super.transform(tree)
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case _ =>
dotc.transform.BetaReduce(tree).withSpan(tree.span)
}
}.transform(tree)
if tree1 == tree then None else Some(tree1)
end Term

given TermMethods: TermMethods with
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
object Expr {

/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the result of beta-reducing the application.
* Otherwise returns `expr`.
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes the top most call by returning the result of beta-reducing the application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `expr`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
import quotes.reflect.*
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** Methods of the module object `val Term` */
trait TermModule { this: Term.type =>

/** Returns a term that is functionally equivalent to `t`,
/** Returns a term that is functionally equivalent to `t`,
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the `Some`
* with the result of beta-reducing the application.
* then it optimizes the top most call by returning `Some`
* with the result of beta-reducing the function application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `None`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce(term: Term): Option[Term]

Expand Down
80 changes: 80 additions & 0 deletions tests/pos-macros/i17506/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
class Foo
class Bar
class Baz

import scala.quoted._

def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
import quotes.reflect._
val reducedMaybe = Term.betaReduce(applied.asTerm)
assert(reducedMaybe.isDefined)
val reduced = reducedMaybe.get
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
reduced

inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val contextual$3: Bar = new Bar()
| val contextual$2: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(using new Foo())(using new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
${regularCurriedFun2BetaReduceTestImpl('f)}
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
${typeParamCurriedFun2BetaReduceTestImpl('f)}
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun2[Y, X](y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
assertBetaReduction(applied, expected).asExprOf[Unit]

inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
${regularCurriedFun3BetaReduceTestImpl('f)}
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val i: Baz = new Baz()
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())(new Baz())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
${typeParamCurriedFun3BetaReduceTestImpl('f)}
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Z = Baz
| val z: Baz = new Baz()
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun3[Z, Y, X](z, y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
assertBetaReduction(applied, expected).asExprOf[Unit]
11 changes: 11 additions & 0 deletions tests/pos-macros/i17506/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@main def run() =
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)

regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
regularCurriedCtxFun2BetaReduceTest(123)
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))

regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 comments on commit 8439370

Please sign in to comment.