diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index fd4a48d2db330..474c453643365 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -173,9 +173,11 @@ function renderDagViz(forJob) { }); metadataContainer().selectAll(".barrier-rdd").each(function() { - var rddId = d3.select(this).text().trim(); - var clusterId = VizConstants.clusterPrefix + rddId; - svg.selectAll("g." + clusterId).classed("barrier", true) + var opId = d3.select(this).text().trim(); + var opClusterId = VizConstants.clusterPrefix + opId; + var stageId = $(this).parents(".stage-metadata").attr("stage-id"); + var stageClusterId = VizConstants.graphPrefix + stageId; + svg.selectAll("g[id=" + stageClusterId + "] g." + opClusterId).classed("barrier", true) }); resizeSvg(svg); diff --git a/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala index 06b6483717a65..4b018f69b1660 100644 --- a/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/RealBrowserUISeleniumSuite.scala @@ -100,6 +100,34 @@ abstract class RealBrowserUISeleniumSuite(val driverProp: String) } } + test("SPARK-31886: Color barrier execution mode RDD correctly") { + withSpark(newSparkContext()) { sc => + sc.parallelize(1 to 10).barrier.mapPartitions(identity).repartition(1).collect() + + eventually(timeout(10.seconds), interval(50.milliseconds)) { + goToUi(sc, "/jobs/job/?id=0") + webDriver.findElement(By.id("job-dag-viz")).click() + + val stage0 = webDriver.findElement(By.cssSelector("g[id='graph_0']")) + val stage1 = webDriver.findElement(By.cssSelector("g[id='graph_1']")) + val barrieredOps = webDriver.findElements(By.className("barrier-rdd")).iterator() + + while (barrieredOps.hasNext) { + val barrieredOpId = barrieredOps.next().getAttribute("innerHTML") + val foundInStage0 = + stage0.findElements( + By.cssSelector("g.barrier.cluster.cluster_" + barrieredOpId)) + assert(foundInStage0.size === 1) + + val foundInStage1 = + stage1.findElements( + By.cssSelector("g.barrier.cluster.cluster_" + barrieredOpId)) + assert(foundInStage1.size === 0) + } + } + } + } + /** * Create a test SparkContext with the SparkUI enabled. * It is safe to `get` the SparkUI directly from the SparkContext returned here.