diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index a0beb2febe..a149b854a9 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -452,34 +452,10 @@ class Router(formatOps: FormatOps) { val (nlCost, nlArrowPenalty) = if (!nl.isNL) (0, 0) else if (slbMod eq noSplitMod) (1 + nlPenalty, nlPenalty) - else { - def argClausePenalty(t: Term.ArgClause)(isFunc: Term => Boolean) = - t.values match { - case arg :: Nil => - if (isFunc(arg)) Some((nestedApplies(t), 2)) - else t.parent match { - case Some(p: Term.Apply) => - Some((nestedApplies(p), treeDepth(p.fun))) - case _ => None - } - case _ => None - } - leftOwner match { - case t: Term.ArgClause => - argClausePenalty(t)(_.is[Term.FunctionTerm]) - case t: Term => t.parent match { - case Some(p: Term.ArgClause) => argClausePenalty(p) { - case Term.Block((_: Term.FunctionTerm) :: Nil) => - getHead(t) eq ft - case _ => false - } - case _ => None - } - case _ => None + else getLambdaPenaltiesOnLeftBraceOnLeft(ft) + .fold((1, nlPenalty)) { case (sharedPenalty, herePenalty) => + (sharedPenalty + herePenalty, sharedPenalty) } - }.fold((1, nlPenalty)) { case (sharedPenalty, herePenalty) => - (sharedPenalty + herePenalty, sharedPenalty) - } val newlineBeforeClosingCurly = decideNewlinesOnlyBeforeClose(close) val nlPolicy = lambdaNLPolicy ==> newlineBeforeClosingCurly val nlSplit = Split(nl, nlCost, policy = nlPolicy) @@ -1949,17 +1925,16 @@ class Router(formatOps: FormatOps) { } val exclude = insideBracesBlock(ft, end, parensToo = true) .excludeCloseDelim - val bracesToParens = ftAfterRight.right.is[T.OpenDelim] && - (ftAfterRight.rightOwner match { - case t: Member.ArgClause => t.values.lengthCompare(1) == 0 - case t: Term.Block => t.parent.is[Member.ArgClause] - case _ => false - }) && { - implicit val ft: FT = next(ftAfterRight) - val rb = matchingRight(ftAfterRight) - getBracesToParensMod(rb, Space, isWithinBraces = true)._1 ne - Space - } + val ftNextAfterRight = next(ftAfterRight) + val singleArg = + if (!ftAfterRight.right.is[T.OpenDelim]) None + else getSingleArgOnLeftBraceOnLeft(ftNextAfterRight).map(_._2) + val bracesToParens = singleArg.isDefined && { + implicit val ft: FT = ftNextAfterRight + val rb = matchingRight(ftAfterRight) + getBracesToParensMod(rb, Space, isWithinBraces = true)._1 ne + Space + } val noSplit = Split(modSpace, 0) .withSingleLine(end, exclude = exclude) Seq(noSplit, nlSplitBase(if (bracesToParens) 0 else 1)) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala index 902c1a61cb..d857d4125b 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TreeOps.scala @@ -350,6 +350,42 @@ object TreeOps { math.max(res, treeDepth(t)) } + def getSingleArgOnLeftBraceOnLeft(ft: FT)(implicit + ftoks: FormatTokens, + ): Option[(Term.ArgClause, Stat)] = ft.leftOwner match { + case ac: Term.ArgClause => (ac.values match { + case (t: Term.Block) :: Nil if ftoks.getHead(t) eq ft => + getBlockSingleStat(t) + case t :: Nil => Some(t) + case _ => None + }).map(x => (ac, x)) + case t: Term => t.parent match { + case Some(ac: Term.ArgClause) if ac.values.lengthCompare(1) == 0 => + (t match { + case t: Term.Block if ftoks.getHead(ac) eq ft => + getBlockSingleStat(t) + case _ => Some(t) + }).map(x => (ac, x)) + case _ => None + } + case _ => None + } + + def getSingleArgLambdaPenalties( + ac: Term.ArgClause, + arg: Stat, + ): Option[(Int, Int)] = + if (arg.is[Term.FunctionTerm]) Some((nestedApplies(ac), 2)) + else ac.parent match { + case Some(p: Term.Apply) => Some((nestedApplies(p), treeDepth(p.fun))) + case _ => None + } + + def getLambdaPenaltiesOnLeftBraceOnLeft(ft: FormatToken)(implicit + ftoks: FormatTokens, + ): Option[(Int, Int)] = getSingleArgOnLeftBraceOnLeft(ft) + .flatMap((getSingleArgLambdaPenalties _).tupled) + final def canBreakAfterFuncArrow(func: Term.FunctionTerm)(implicit ftoks: FormatTokens, style: ScalafmtConfig,