diff --git a/core/src/main/scala/cats/data/NonEmptyList.scala b/core/src/main/scala/cats/data/NonEmptyList.scala index 60e7333924..03ef544dcd 100644 --- a/core/src/main/scala/cats/data/NonEmptyList.scala +++ b/core/src/main/scala/cats/data/NonEmptyList.scala @@ -7,6 +7,7 @@ import cats.syntax.order._ import scala.annotation.tailrec import scala.collection.immutable.TreeSet import scala.collection.mutable.ListBuffer +import scala.collection.{immutable, mutable} /** * A data type which represents a non empty list of A, with @@ -217,6 +218,29 @@ final case class NonEmptyList[+A](head: A, tail: List[A]) { NonEmptyList((head, 0), bldr.result) } + /** + * Groups elements inside of this `NonEmptyList` using a mapping function + * + * {{{ + * scala> import cats.data.NonEmptyList + * scala> val nel = NonEmptyList.of(12, -2, 3, -5) + * scala> nel.groupBy(_ >= 0) + * res0: Map[Boolean, cats.data.NonEmptyList[Int]] = Map(false -> NonEmptyList(-2, -5), true -> NonEmptyList(12, 3)) + * }}} + */ + def groupBy[B](f: A => B): Map[B, NonEmptyList[A]] = { + val m = mutable.Map.empty[B, mutable.Builder[A, List[A]]] + for { elem <- toList } { + m.getOrElseUpdate(f(elem), List.newBuilder[A]) += elem + } + val b = immutable.Map.newBuilder[B, NonEmptyList[A]] + for { (k, v) <- m } { + val head :: tail = v.result // we only create non empty list inside of the map `m` + b += ((k, NonEmptyList(head, tail))) + } + b.result + } + } object NonEmptyList extends NonEmptyListInstances { diff --git a/core/src/main/scala/cats/syntax/list.scala b/core/src/main/scala/cats/syntax/list.scala index 88defb39da..7c2a2c0aca 100644 --- a/core/src/main/scala/cats/syntax/list.scala +++ b/core/src/main/scala/cats/syntax/list.scala @@ -9,4 +9,6 @@ trait ListSyntax { final class ListOps[A](val la: List[A]) extends AnyVal { def toNel: Option[NonEmptyList[A]] = NonEmptyList.fromList(la) + def groupByNel[B](f: A => B): Map[B, NonEmptyList[A]] = + toNel.fold(Map.empty[B, NonEmptyList[A]])(_.groupBy(f)) } diff --git a/tests/src/test/scala/cats/tests/ListTests.scala b/tests/src/test/scala/cats/tests/ListTests.scala index d89c0251c9..308e067278 100644 --- a/tests/src/test/scala/cats/tests/ListTests.scala +++ b/tests/src/test/scala/cats/tests/ListTests.scala @@ -29,6 +29,12 @@ class ListTests extends CatsSuite { List.empty[Int].toNel should === (None) } + test("groupByNel should be consistent with groupBy")( + forAll { (fa: List[Int], f: Int => Int) => + fa.groupByNel(f).mapValues(_.toList) should === (fa.groupBy(f)) + } + ) + test("show"){ List(1, 2, 3).show should === ("List(1, 2, 3)") (Nil: List[Int]).show should === ("List()") diff --git a/tests/src/test/scala/cats/tests/NonEmptyListTests.scala b/tests/src/test/scala/cats/tests/NonEmptyListTests.scala index 166a40d53d..911377fed7 100644 --- a/tests/src/test/scala/cats/tests/NonEmptyListTests.scala +++ b/tests/src/test/scala/cats/tests/NonEmptyListTests.scala @@ -231,6 +231,12 @@ class NonEmptyListTests extends CatsSuite { nel.zipWithIndex.toList should === (nel.toList.zipWithIndex) } } + + test("NonEmptyList#groupBy is consistent with List#groupBy") { + forAll { (nel: NonEmptyList[Int], f: Int => Int) => + nel.groupBy(f).mapValues(_.toList) should === (nel.toList.groupBy(f)) + } + } } class ReducibleNonEmptyListCheck extends ReducibleCheck[NonEmptyList]("NonEmptyList") {