Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to beta reduce curried function applications in quotes reflect #18121

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -396,17 +396,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end TermTypeTest

object Term extends TermModule:
def betaReduce(tree: Term): Option[Term] =
tree match
case tpd.Block(Nil, expr) =>
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
else Some(tree1.withSpan(tree.span))

def betaReduce(tree: Term): Option[Term] =
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
override def transform(tree: Tree)(using Context): Tree = tree match {
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
super.transform(tree)
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case _ =>
dotc.transform.BetaReduce(tree).withSpan(tree.span)
}
}.transform(tree)
if tree1 == tree then None else Some(tree1)
end Term

given TermMethods: TermMethods with
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
object Expr {

/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the result of beta-reducing the application.
* Otherwise returns `expr`.
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes the top most call by returning the result of beta-reducing the application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `expr`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still belive this is to complex for a first-time user. I would call this the general rule.

Suggested change
* Example:
* Generally:

I would add a trivial beta reduction example before this.

Example:
```scala sc:nocompile
((a: Int, b: Int) => a + b).apply(x, y)
```
will be reduced to
```scala sc:nocompile
val a = x
val b = y
a + b
```

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see that you did that in the other file. We should have the same example in both documentations. I believe that my new example will be better for new users that are less experienced.

* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
import quotes.reflect.*
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** Methods of the module object `val Term` */
trait TermModule { this: Term.type =>

/** Returns a term that is functionally equivalent to `t`,
/** Returns a term that is functionally equivalent to `t`,
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the `Some`
* with the result of beta-reducing the application.
* then it optimizes the top most call by returning `Some`
* with the result of beta-reducing the function application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `None`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
nicolasstucki marked this conversation as resolved.
Show resolved Hide resolved
* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce(term: Term): Option[Term]

Expand Down
80 changes: 80 additions & 0 deletions tests/pos-macros/i17506/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
class Foo
class Bar
class Baz

import scala.quoted._

def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
import quotes.reflect._
val reducedMaybe = Term.betaReduce(applied.asTerm)
assert(reducedMaybe.isDefined)
val reduced = reducedMaybe.get
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
reduced

inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val contextual$3: Bar = new Bar()
| val contextual$2: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(using new Foo())(using new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
${regularCurriedFun2BetaReduceTestImpl('f)}
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
${typeParamCurriedFun2BetaReduceTestImpl('f)}
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun2[Y, X](y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
assertBetaReduction(applied, expected).asExprOf[Unit]

inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
${regularCurriedFun3BetaReduceTestImpl('f)}
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val i: Baz = new Baz()
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())(new Baz())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
${typeParamCurriedFun3BetaReduceTestImpl('f)}
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Z = Baz
| val z: Baz = new Baz()
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun3[Z, Y, X](z, y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
assertBetaReduction(applied, expected).asExprOf[Unit]
11 changes: 11 additions & 0 deletions tests/pos-macros/i17506/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@main def run() =
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)

regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
regularCurriedCtxFun2BetaReduceTest(123)
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))

regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))
Loading