Skip to content

Commit

Permalink
Add support for subdags of subdags
Browse files Browse the repository at this point in the history
Signed-off-by: droctothorpe <[email protected]>
Co-authored-by: zazulam <[email protected]>
Co-authored-by: CarterFendley <[email protected]>
  • Loading branch information
3 people committed Sep 11, 2024
1 parent 1e62d0d commit a0a7b7b
Showing 1 changed file with 92 additions and 75 deletions.
167 changes: 92 additions & 75 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,44 @@ func validateNonRoot(opts Options) error {
return nil
}

// getDAGTasks gets all the tasks associated with the specified DAG and all of
// its subDAGs.
func getDAGTasks(
ctx context.Context,
dag *metadata.DAG,
pipeline *metadata.Pipeline,
mlmd *metadata.Client,
flattenedTasks map[string]*metadata.Execution,
) (map[string]*metadata.Execution, error) {
if flattenedTasks == nil {
flattenedTasks = make(map[string]*metadata.Execution)
}
currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
if err != nil {
return nil, err
}
for k, v := range currentExecutionTasks {
flattenedTasks[k] = v
}
for _, v := range currentExecutionTasks {
if v.GetExecution().GetType() == "system.DAGExecution" {
glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName())
subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId())
if err != nil {
return nil, err
}
// Pass the subDAG into a recursive call to getDAGTasks and update
// tasks to include the subDAG's tasks.
flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks)
if err != nil {
return nil, err
}
}
}

return flattenedTasks, nil
}

func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, pipeline *metadata.Pipeline, task *pipelinespec.PipelineTaskSpec, inputsSpec *pipelinespec.ComponentInputsSpec, mlmd *metadata.Client, expr *expression.Expr) (inputs *pipelinespec.ExecutorInput_Inputs, err error) {
defer func() {
if err != nil {
Expand Down Expand Up @@ -1138,37 +1176,6 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
}
return inputs, nil
}
// get executions in context on demand
var tasksCache map[string]*metadata.Execution
getDAGTasks := func() (map[string]*metadata.Execution, error) {
if tasksCache != nil {
return tasksCache, nil
}
tasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
if err != nil {
return nil, err
}
// TODO: Make this recursive.
for _, v := range tasks {
if v.GetExecution().GetType() == "system.DAGExecution" {
glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName())
dag, err := mlmd.GetDAG(ctx, v.GetExecution().GetId())
if err != nil {
return nil, err
}
subdagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline)
if err != nil {
return nil, err
}
for k, v := range subdagTasks {
tasks[k] = v
}
}
}
tasksCache = tasks

return tasks, nil
}

for name, paramSpec := range task.GetInputs().GetParameters() {
glog.V(4).Infof("name: %v", name)
Expand Down Expand Up @@ -1199,60 +1206,70 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
if taskOutput.GetOutputParameterKey() == "" {
return nil, paramError(fmt.Errorf("output parameter key is empty"))
}
tasks, err := getDAGTasks()
tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil)
if err != nil {
return nil, paramError(err)
}

// The producer is the task that produces the output that we need to
// consume.
producer, ok := tasks[taskOutput.GetProducerTask()]
// If the producer is a DAG, AND its output / producer subtask is
// ALSO a DAG, then we need to cycle through this loop until we
// arrive at a non-DAG subtask and essentially bubble up that
// non-DAG subtask so that its value can be consumed.
producerSubTaskMaybeDAG := true
for producerSubTaskMaybeDAG {
// The producer is the task that produces the output that we need to
// consume.
producer := tasks[taskOutput.GetProducerTask()]

glog.V(4).Info("producer: ", producer)
glog.V(4).Info("producer: ", producer)

// Get the producer's outputs.
_, producerOutputs, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
glog.V(4).Info("producer output parameters: ", producerOutputs)
// Deserialize them.
var producerOutputsMap map[string]string
b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
if err != nil {
return nil, err
}
json.Unmarshal(b, &producerOutputsMap)
glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)

// If the producer's output includes a producer subtask, which means
// that the producer is a DAG that is getting its output from one of
// the tasks in the DAG, then we want to roll up the output from the
// producer subtask to the producer, so that the downstream logic
// can retrieve it appropriately.
if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
glog.V(4).Infof(
"Overriding producer task, %v, output with producer_subtask, %v, output.",
producer.TaskName(),
producerSubTask,
)
_, producerOutputs, err = tasks[producerSubTask].GetParameters()
// Get the producer's outputs.
_, producerOutputs, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
glog.V(4).Info("producer output parameters: ", producerOutputs)
// Deserialize them.
var producerOutputsMap map[string]string
b, err := producerOutputs["Output"].GetStructValue().MarshalJSON()
if err != nil {
return nil, err
}
glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
// The only reason we're updating this is to make the downstream
// logging more accurate.
taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
json.Unmarshal(b, &producerOutputsMap)
glog.V(4).Info("producerOutputsMap: ", producerOutputsMap)

// If the producer's output includes a producer subtask, which means
// that the producer is a DAG that is getting its output from one of
// the tasks in the DAG, then we want to roll up the output from the
// producer subtask to the producer, so that the downstream logic
// can retrieve it appropriately.
if producerSubTask, ok := producerOutputsMap["producer_subtask"]; ok {
glog.V(4).Infof(
"Overriding producer task, %v, output with producer_subtask, %v, output.",
producer.TaskName(),
producerSubTask,
)
_, producerOutputs, err = tasks[producerSubTask].GetParameters()
if err != nil {
return nil, err
}
glog.V(4).Info("producerSubTask output parameters: ", producerOutputs)
// The only reason we're updating this is to make the downstream
// logging more accurate.
taskOutput.ProducerTask = producerOutputsMap["producer_subtask"]
// Grab the value of the producer output.
producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
// Update the input to be the producer output value.
inputs.ParameterValues[name] = producerOutputValue
} else {
// The producer subtask is not a DAG, so we exit the loop.
producerSubTaskMaybeDAG = false
}
}

// Grab the value of the producer output.
producerOutputValue, ok := producerOutputs[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
// Update the input to be the producer output value.
inputs.ParameterValues[name] = producerOutputValue
case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
runtimeValue := paramSpec.GetRuntimeValue()
switch t := runtimeValue.Value.(type) {
Expand Down Expand Up @@ -1292,7 +1309,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
if taskOutput.GetOutputArtifactKey() == "" {
return nil, artifactError(fmt.Errorf("output artifact key is empty"))
}
tasks, err := getDAGTasks()
tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil)
if err != nil {
return nil, artifactError(err)
}
Expand Down

0 comments on commit a0a7b7b

Please sign in to comment.