diff --git a/tasks/src/main/scala/dagr/tasks/ScatterGather.scala b/tasks/src/main/scala/dagr/tasks/ScatterGather.scala
index 5d31f861..36a6425a 100644
--- a/tasks/src/main/scala/dagr/tasks/ScatterGather.scala
+++ b/tasks/src/main/scala/dagr/tasks/ScatterGather.scala
@@ -115,6 +115,43 @@ object ScatterGather {
     }
   }
 
+  /**
+   * Implementation of a Partitioner that takes the partitions from an existing Partitioner, and then groups them
+   * by the key that `f` computes for each partition.
+   *
+   * Groups are emitted in the order in which their key is first seen in the upstream partitions, and each group
+   * keeps its members in upstream order, so the result does not depend on the iteration order of a Map.
+   *
+   * @param partitioner the upstream partitioner whose partitions are grouped
+   * @param f computes the grouping key of a partition; applied once per upstream partition
+   * @tparam Result the type of the upstream partitions
+   * @tparam Key the type of the grouping key
+   */
+  private class GroupByPartitioner[Result, Key](partitioner: Partitioner[Result], f: Result => Key)
+    extends SimpleInJvmTask with Partitioner[(Key, Seq[Result])] {
+    /** The grouped partitions; None until run() has populated it. */
+    var partitions: Option[Seq[(Key, Seq[Result])]] = None
+
+    /**
+     * Groups the upstream partitions by key and stores them in `partitions`.
+     *
+     * @throws IllegalStateException if the upstream partitioner has not populated its partitions yet
+     */
+    override def run(): Unit = {
+      val upstream = partitioner.partitions.getOrElse {
+        throw new IllegalStateException("partitioner.partitions called before partitions populated.")
+      }
+      // pair every partition with its key up front so that f is applied exactly once per partition
+      val keyed = upstream.map { result => f(result) -> result }
+      val byKey = keyed.groupBy { case (key, _) => key }
+      // walk the keys in first-seen order to make the order of the groups deterministic
+      val grouped = keyed.map { case (key, _) => key }.distinct.map { key =>
+        key -> byKey(key).map { case (_, result) => result }
+      }
+      this.partitions = Some(grouped)
+    }
+  }
+
   /**
    * Implementation of a Scatter that is just a thinly veiled wrapper around the Scatterer
    * being used to generate the set of scatters/partitions to operate on.
@@ -126,8 +137,20 @@ object ScatterGather {
     override def gather[NextResult <: Task](f: Seq[Result] => NextResult): Gather[Result,NextResult] =
       throw new UnsupportedOperationException("gather not supported on an unmapped Scatter")
 
-    override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] =
-      throw new UnsupportedOperationException("groupBy not supported on an unmapped Scatter")
+    /**
+     * Groups the results of this scatter by the key that `f` computes for each of them.  The scatter obtained
+     * from a GroupByPartitioner over this scatter's partitioner is chained after this scatter with `==>`.
+     *
+     * @param f computes the grouping key of a result
+     * @return a scatter over (key, results with that key) pairs
+     */
+    override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] = {
+      val grouper = new GroupByPartitioner[Result, Key](partitioner, f)
+      // resolve the scatter once so that the instance chained below is the very instance that is returned
+      val grouped = grouper.scatter
+      this ==> grouped
+      grouped
+    }
 
     override def flatMap[NextResult](f: Result => Scatter[NextResult]) : Scatter[NextResult] = {
       this.map(f).flatMap(identity)
diff --git a/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala b/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
index 911d40da..d35db95b 100644
--- a/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
+++ b/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
@@ -107,6 +107,14 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
     def run(): Unit = Io.writeLines(output, Seq(number.toString))
   }
 
+  /** Task that writes the two numbers to the output as one line, separated by a tab. */
+  private case class WriteNumberTuple(numbers: (Int, Int), output: Path) extends SimpleInJvmTask {
+    def run(): Unit = {
+      val (first, second) = numbers
+      Io.writeLines(output, Seq(s"$first\t$second"))
+    }
+  }
+
   "ScatterGather" should "run a simple scatter-gather pipeline on files" in {
     val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
     val lengths = Seq(1,2,3,4,5)
@@ -132,11 +140,7 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
 
     val taskManager = buildTaskManager
     taskManager.addTask(pipeline)
-    taskManager.runToCompletion(true).foreach { case (task, info) =>
-      if (TaskStatus.isTaskNotDone(info.status)) {
-        println(s"${task.name} $info")
-      }
-    }
+    taskManager.runToCompletion(true)
 
     val sum1 = Io.readLines(sumOfCounts).next().toInt
     val sum2 = Io.readLines(sumOfSquares).next().toInt
@@ -351,4 +355,53 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
 
     sumFlatMap shouldBe lengths.sum
   }
+
+  it should "flatMap a scatter, then groupBy and map" in {
+    val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
+
+    // setup the input and output
+    val input = tmp()
+    val countsByWordLengthOut = tmp()
+    Io.writeLines(input, lines)
+
+    val pipeline = new Pipeline() {
+      override def build(): Unit = {
+        // the initial scatter: scatters across lines
+        val scatter: Scatter[Path] = Scatter(SplitByLine(input=input))
+
+        // scatter from flatMap: each line is scattered across words, then flatMap makes a scatter across all words (all lines)
+        val scatterByWordFlatMap: Scatter[Path] = scatter.flatMap { pathToLine =>
+          val wordScatter = Scatter(SplitLineByWord(pathToLine))
+          root ==> wordScatter
+          wordScatter
+        }
+
+        // group by word length
+        val groupedByWordLength = scatterByWordFlatMap.groupBy { pathToWord => Io.readLines(pathToWord).next().length }
+
+        // map: count how many occurrences of words of a given length
+        val countsByWordLength = groupedByWordLength.map { case (wordLength, paths) =>
+          WriteNumberTuple(numbers=(wordLength, paths.length), output=tmp())
+        }
+
+        // gather: concatenate them all
+        countsByWordLength.gather { tasks => Concat(inputs = tasks.map(_.output), output = countsByWordLengthOut) }
+
+        root ==> scatter
+      }
+    }
+
+    val taskManager = buildTaskManager
+    taskManager.addTask(pipeline)
+    taskManager.runToCompletion(true)
+
+    // each output line is "<word length>\t<count>"; sort so the assertion does not depend on the order of the groups
+    val countsByLength = Io.readLines(countsByWordLengthOut).toList.map { line =>
+      val Array(wordLength, count) = line.split("\t")
+      (wordLength.toInt, count.toInt)
+    }.sortBy { case (wordLength, _) => wordLength }
+
+    // length 3: "one" x5 and "two" x4; length 4: "four" x2 and "five" x1; length 5: "three" x3
+    countsByLength should contain theSameElementsInOrderAs Seq((3, 9), (4, 3), (5, 3))
+  }
 }