diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 256940645ec3..b837473ff22c 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -396,17 +396,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler end TermTypeTest object Term extends TermModule: - def betaReduce(tree: Term): Option[Term] = - tree match - case tpd.Block(Nil, expr) => - for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e) - case tpd.Inlined(_, Nil, expr) => - betaReduce(expr) - case _ => - val tree1 = dotc.transform.BetaReduce(tree) - if tree1 eq tree then None - else Some(tree1.withSpan(tree.span)) - + def betaReduce(tree: Term): Option[Term] = + val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap { + override def transform(tree: Tree)(using Context): Tree = tree match { + case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) => + super.transform(tree) + case tpd.Apply(sel @ tpd.Select(expr, nme), args) => + val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args) + dotc.transform.BetaReduce(tree1).withSpan(tree.span) + case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) => + val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args) + dotc.transform.BetaReduce(tree1).withSpan(tree.span) + case _ => + dotc.transform.BetaReduce(tree).withSpan(tree.span) + } + }.transform(tree) + if tree1 == tree then None else Some(tree1) end Term given TermMethods: TermMethods with diff --git a/library/src/scala/quoted/Expr.scala b/library/src/scala/quoted/Expr.scala index 525f647eaaac..f1045e5bdaca 100644 --- a/library/src/scala/quoted/Expr.scala +++ b/library/src/scala/quoted/Expr.scala @@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] () object Expr { /** `e.betaReduce` returns an expression that is functionally equivalent to `e`, - * however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)` - * then it optimizes this the top most call by returning the result of beta-reducing the application. - * Otherwise returns `expr`. + * however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)` + * then it optimizes the top most call by returning the result of beta-reducing the application. + * Similarly, all outermost curried function applications will be beta-reduced, if possible. + * Otherwise returns `expr`. * - * To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`. - * Some bindings may be elided as an early optimization. + * To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`. + * Some bindings may be elided as an early optimization. + * + * Example: + * ```scala sc:nocompile + * ((a: Int, b: Int) => a + b).apply(x, y) + * ``` + * will be reduced to + * ```scala sc:nocompile + * val a = x + * val b = y + * a + b + * ``` + * + * Generally: + * ```scala sc:nocompile + * ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...) + * ``` + * will be reduced to + * ```scala sc:nocompile + * type X1 = Tx1 + * type Y1 = Ty1 + * ... + * val x1 = myX1 + * val y1 = myY1 + * ... + * type Xn = Txn + * type Yn = Tyn + * ... + * val xn = myXn + * val yn = myYn + * ... + * f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...) + * ``` */ def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] = import quotes.reflect.* diff --git a/library/src/scala/quoted/Quotes.scala b/library/src/scala/quoted/Quotes.scala index fa96b73551d1..55e66ff90da8 100644 --- a/library/src/scala/quoted/Quotes.scala +++ b/library/src/scala/quoted/Quotes.scala @@ -774,14 +774,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => /** Methods of the module object `val Term` */ trait TermModule { this: Term.type => - /** Returns a term that is functionally equivalent to `t`, + /** Returns a term that is functionally equivalent to `t`, * however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)` - * then it optimizes this the top most call by returning the `Some` - * with the result of beta-reducing the application. + * then it optimizes the top most call by returning `Some` + * with the result of beta-reducing the function application. + * Similarly, all outermost curried function applications will be beta-reduced, if possible. * Otherwise returns `None`. * - * To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`. - * Some bindings may be elided as an early optimization. + * To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`. + * Some bindings may be elided as an early optimization. + * + * Example: + * ```scala sc:nocompile + * ((a: Int, b: Int) => a + b).apply(x, y) + * ``` + * will be reduced to + * ```scala sc:nocompile + * val a = x + * val b = y + * a + b + * ``` + * + * Generally: + * ```scala sc:nocompile + * ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...) + * ``` + * will be reduced to + * ```scala sc:nocompile + * type X1 = Tx1 + * type Y1 = Ty1 + * ... + * val x1 = myX1 + * val y1 = myY1 + * ... + * type Xn = Txn + * type Yn = Tyn + * ... + * val xn = myXn + * val yn = myYn + * ... + * f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...) + * ``` */ def betaReduce(term: Term): Option[Term] diff --git a/tests/pos-macros/i17506/Macro_1.scala b/tests/pos-macros/i17506/Macro_1.scala new file mode 100644 index 000000000000..a66428a126be --- /dev/null +++ b/tests/pos-macros/i17506/Macro_1.scala @@ -0,0 +1,80 @@ +class Foo +class Bar +class Baz + +import scala.quoted._ + +def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term = + import quotes.reflect._ + val reducedMaybe = Term.betaReduce(applied.asTerm) + assert(reducedMaybe.isDefined) + val reduced = reducedMaybe.get + assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}") + reduced + +inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit = + ${regularCurriedCtxFun2BetaReduceTestImpl('f)} +def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] = + val expected = + """|{ + | val contextual$3: Bar = new Bar() + | val contextual$2: Foo = new Foo() + | 123 + |}""".stripMargin + val applied = '{$f(using new Foo())(using new Bar())} + assertBetaReduction(applied, expected).asExprOf[Int] + +inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int = + ${regularCurriedFun2BetaReduceTestImpl('f)} +def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] = + val expected = + """|{ + | val b: Bar = new Bar() + | val f: Foo = new Foo() + | 123 + |}""".stripMargin + val applied = '{$f(new Foo())(new Bar())} + assertBetaReduction(applied, expected).asExprOf[Int] + +inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit = + ${typeParamCurriedFun2BetaReduceTestImpl('f)} +def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] = + val expected = + """|{ + | type Y = Bar + | val y: Bar = new Bar() + | type X = Foo + | val x: Foo = new Foo() + | typeParamFun2[Y, X](y, x) + |}""".stripMargin + val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())} + assertBetaReduction(applied, expected).asExprOf[Unit] + +inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int = + ${regularCurriedFun3BetaReduceTestImpl('f)} +def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] = + val expected = + """|{ + | val i: Baz = new Baz() + | val b: Bar = new Bar() + | val f: Foo = new Foo() + | 123 + |}""".stripMargin + val applied = '{$f(new Foo())(new Bar())(new Baz())} + assertBetaReduction(applied, expected).asExprOf[Int] + +inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit = + ${typeParamCurriedFun3BetaReduceTestImpl('f)} +def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] = + val expected = + """|{ + | type Z = Baz + | val z: Baz = new Baz() + | type Y = Bar + | val y: Bar = new Bar() + | type X = Foo + | val x: Foo = new Foo() + | typeParamFun3[Z, Y, X](z, y, x) + |}""".stripMargin + val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())} + assertBetaReduction(applied, expected).asExprOf[Unit] diff --git a/tests/pos-macros/i17506/Test_2.scala b/tests/pos-macros/i17506/Test_2.scala new file mode 100644 index 000000000000..97a146ecba93 --- /dev/null +++ b/tests/pos-macros/i17506/Test_2.scala @@ -0,0 +1,11 @@ +@main def run() = + def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString) + def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString) + + regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123) + regularCurriedCtxFun2BetaReduceTest(123) + regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123)) + typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x)) + + regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123) + typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))