diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java b/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java index 1850060ece..bb0b7f1e48 100644 --- a/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java +++ b/tez-api/src/main/java/org/apache/tez/dag/api/EdgeProperty.java @@ -261,6 +261,14 @@ public EdgeManagerPluginDescriptor getEdgeManagerDescriptor() { return edgeManagerDescriptor; } + /** + * Returns a new EdgeProperty with the given EdgeManagerPluginDescriptor. + */ + public EdgeProperty withDescriptor(EdgeManagerPluginDescriptor newDescriptor) { + return new EdgeProperty(newDescriptor, this.dataMovementType, this.dataSourceType, + this.schedulingType, this.outputDescriptor, this.inputDescriptor); + } + @Override public String toString() { return "{ " + dataMovementType + " : " + inputDescriptor.getClassName() diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java index c7cf176af7..90cb94984f 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java @@ -1998,7 +1998,12 @@ private void setParallelismWrapper(int parallelism, VertexLocationHint vertexLoc Vertex sourceVertex = appContext.getCurrentDAG().getVertex(entry.getKey()); Edge edge = sourceVertices.get(sourceVertex); try { - edge.setEdgeProperty(entry.getValue()); + if (edge != null) { + edge.setEdgeProperty(entry.getValue()); + } else { + LOG.warn("edge = {}, sourceVertex = {}, entry.getValue() = {}", + edge, sourceVertex, entry.getValue()); + } } catch (Exception e) { throw new TezUncheckedException(e); } diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java index b05c45ad96..e5fb612fc0 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java @@ -24,14 +24,8 @@ import com.google.protobuf.InvalidProtocolBufferException; import org.apache.tez.common.TezUtils; -import org.apache.tez.dag.api.EdgeManagerPluginContext; -import org.apache.tez.dag.api.EdgeManagerPluginDescriptor; -import org.apache.tez.dag.api.EdgeManagerPluginOnDemand; -import org.apache.tez.dag.api.TezUncheckedException; -import org.apache.tez.dag.api.UserPayload; -import org.apache.tez.dag.api.VertexManagerPluginContext; +import org.apache.tez.dag.api.*; import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest; -import org.apache.tez.dag.api.VertexManagerPluginDescriptor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.classification.InterfaceAudience.Public; @@ -47,11 +41,7 @@ import java.io.IOException; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; +import java.util.*; /** * Starts scheduling tasks when number of completed source tasks crosses @@ -520,6 +510,30 @@ ReconfigVertexParams computeRouting() { for(Map.Entry entry : bipartiteItr) { entry.getValue().newDescriptor = descriptor; } + + // Additionally, update custom edges. + Map inputEdges = getContext().getInputVertexEdgeProperties(); + Map updatedEdges = new HashMap<>(); + for (Map.Entry entry : inputEdges.entrySet()) { + if (entry.getValue().getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM) { + // Build a new custom edge manager configuration with updated parallelism. + CustomShuffleEdgeManagerConfig customConfig = new CustomShuffleEdgeManagerConfig( + currentParallelism, finalTaskParallelism, basePartitionRange, + (remainderRangeForLastShuffler > 0 ? remainderRangeForLastShuffler : basePartitionRange)); + EdgeManagerPluginDescriptor newDescriptor = EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName()); + newDescriptor.setUserPayload(customConfig.toUserPayload()); + + // Update the EdgeProperty with the new descriptor. + EdgeProperty updatedProp = entry.getValue().withDescriptor(newDescriptor); + updatedEdges.put(entry.getKey(), updatedProp); + } + } + + // If any custom edges were updated, propagate the new configuration. + if (!updatedEdges.isEmpty()) { + getContext().reconfigureVertex(finalTaskParallelism, null, updatedEdges); + } + ReconfigVertexParams params = new ReconfigVertexParams(finalTaskParallelism, null); return params;