Commit
Bug fix: scatter gather flatMap/groupBy fixup (#358)
* groupBy should always be able to follow a flatMap
nh13 authored Aug 15, 2019
1 parent b231013 commit fdf8223
Showing 2 changed files with 65 additions and 7 deletions.
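
Before this fix, calling groupBy on the primary (unmapped) scatter threw an UnsupportedOperationException, which broke the flatMap-then-groupBy chain named in the commit message. As a rough sketch of the pattern this commit enables, condensed from the new test further down (all task and helper names, e.g. SplitByLine and SplitLineByWord, come from that test, and the fragment belongs inside a Pipeline's build() with `input` pointing at the input file):

// Sketch only: condensed from the new ScatterGatherTests case below.
// One partition per line of the input file.
val scatter: Scatter[Path] = Scatter(SplitByLine(input = input))

// flatMap: each line is re-scattered across its words, then flattened into one scatter over all words.
val words: Scatter[Path] = scatter.flatMap { pathToLine =>
  val wordScatter = Scatter(SplitLineByWord(pathToLine))
  root ==> wordScatter
  wordScatter
}

// groupBy directly after flatMap: previously threw, now yields one partition per word length.
val byLength: Scatter[(Int, Seq[Path])] =
  words.groupBy { pathToWord => Io.readLines(pathToWord).next().length }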
tasks/src/main/scala/dagr/tasks/ScatterGather.scala (18 changes: 16 additions & 2 deletions)
@@ -115,6 +115,17 @@ object ScatterGather {
    }
  }

+  /** Implementation of a Partitioner that takes the partitions from an existing Partitioner, and then groups them. */
+  private class GroupByPartitioner[Result, Key](partitioner: Partitioner[Result], f: Result => Key) extends SimpleInJvmTask with Partitioner[(Key, Seq[Result])] {
+    var partitions: Option[Seq[(Key, Seq[Result])]] = None
+    override def run(): Unit = {
+      partitioner.partitions match {
+        case None => throw new IllegalStateException(s"partitioner.partitions called before partitions populated.")
+        case Some(_partition) => this.partitions = Some(_partition.groupBy(f).toSeq)
+      }
+    }
+  }
+
  /**
   * Implementation of a Scatter that is just a thinly veiled wrapper around the Scatterer
   * being used to generate the set of scatters/partitions to operate on.
@@ -126,8 +137,11 @@
    override def gather[NextResult <: Task](f: Seq[Result] => NextResult): Gather[Result,NextResult] =
      throw new UnsupportedOperationException("gather not supported on an unmapped Scatter")

-    override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] =
-      throw new UnsupportedOperationException("groupBy not supported on an unmapped Scatter")
+    override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] = {
+      val grouper = new GroupByPartitioner[Result, Key](partitioner, f)
+      this ==> grouper.scatter
+      grouper.scatter
+    }

    override def flatMap[NextResult](f: Result => Scatter[NextResult]) : Scatter[NextResult] = {
      this.map(f).flatMap(identity)
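Within the new GroupByPartitioner, the grouping itself is just Scala's standard groupBy on the upstream partitions followed by toSeq. A standalone illustration of that one step, outside dagr and with made-up values:

// Mirrors the shape of: this.partitions = Some(_partition.groupBy(f).toSeq)
object GroupByDemo extends App {
  val upstream: Seq[String] = Seq("one", "two", "three", "four", "five")
  val f: String => Int      = _.length

  val grouped: Seq[(Int, Seq[String])] = upstream.groupBy(f).toSeq
  // Yields pairs such as (3, Seq("one", "two")), (4, Seq("four", "five")), (5, Seq("three"));
  // key order is not guaranteed, which is fine because each pair becomes its own partition.
  grouped.foreach(println)
}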
tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala (54 changes: 49 additions & 5 deletions)
@@ -107,6 +107,10 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAll
    def run(): Unit = Io.writeLines(output, Seq(number.toString))
  }

+  private case class WriteNumberTuple(numbers: (Int, Int), output: Path) extends SimpleInJvmTask {
+    def run(): Unit = Io.writeLines(output, Seq(Seq(numbers._1, numbers._2).map(_.toString).mkString("\t")))
+  }
+
  "ScatterGather" should "run a simple scatter-gather pipeline on files" in {
    val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
    val lengths = Seq(1,2,3,4,5)
@@ -132,11 +136,7 @@

    val taskManager = buildTaskManager
    taskManager.addTask(pipeline)
-    taskManager.runToCompletion(true).foreach { case (task, info) =>
-      if (TaskStatus.isTaskNotDone(info.status)) {
-        println(s"${task.name} $info")
-      }
-    }
+    taskManager.runToCompletion(true)

    val sum1 = Io.readLines(sumOfCounts).next().toInt
    val sum2 = Io.readLines(sumOfSquares).next().toInt
@@ -351,4 +351,48 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAll

    sumFlatMap shouldBe lengths.sum
  }
+
+  it should "flatMap a scatter, then groupBy and map" in {
+    val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
+
+    // setup the input and output
+    val input = tmp()
+    val countsByWordLengthOut = tmp()
+    Io.writeLines(input, lines)
+
+    val pipeline = new Pipeline() {
+      override def build(): Unit = {
+        // the initial scatter: scatters across lines
+        val scatter: Scatter[Path] = Scatter(SplitByLine(input=input))
+
+        // scatter from flatMap: each line is scattered across words, then flatMap makes a scatter across all words (all lines)
+        val scatterByWordFlatMap: Scatter[Path] = scatter.flatMap { pathToLine =>
+          val scatter = Scatter(SplitLineByWord(pathToLine))
+          root ==> scatter
+          scatter
+        }
+
+        // group by word length
+        val groupedByWordLength = scatterByWordFlatMap.groupBy { pathToWord => Io.readLines(pathToWord).next().length }
+
+        // map: count how many occurrences of words of a given length
+        val countsByWordLength = groupedByWordLength.map { case (wordLength, tasks) => WriteNumberTuple(numbers=(wordLength, tasks.length), output=tmp()) }
+
+        // gather: concatenate them all
+        countsByWordLength.gather { tasks => Concat(inputs = tasks.map(_.output), output = countsByWordLengthOut) }
+
+        root ==> scatter
+      }
+    }
+
+    val taskManager = buildTaskManager
+    taskManager.addTask(pipeline)
+    taskManager.runToCompletion(true)
+
+    val outLines = Io.readLines(countsByWordLengthOut).toList
+    outLines.map { line =>
+      val Array(wordLength: String, count: String) = line.split("\t")
+      (wordLength.toInt, count.toInt)
+    }.sortBy(_._1) should contain theSameElementsInOrderAs Seq((3, 9), (4, 3), (5, 3))
+  }
}
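
As a sanity check on the expected output in the new test: across the five input lines there are 15 words in total, words of length 3 ("one" five times, "two" four times) occur 9 times, length 4 ("four" twice, "five" once) 3 times, and length 5 ("three" three times) 3 times, which is exactly Seq((3, 9), (4, 3), (5, 3)). The arithmetic can be reproduced outside dagr:

// Recomputes the expected (wordLength, count) pairs from the test's input lines.
object ExpectedCountsCheck extends App {
  val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")

  val expected: Seq[(Int, Int)] = lines
    .flatMap(_.split(" "))                       // all 15 words across all lines
    .groupBy(_.length)                           // group words by their length
    .map { case (len, words) => (len, words.size) }
    .toSeq
    .sortBy(_._1)

  assert(expected == Seq((3, 9), (4, 3), (5, 3)))
}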
