diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 2b5297c7b09..a80091699bc 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -148,6 +148,17 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) { } } glog.Infof("publish success.") + // At the end of the current task, we check the statuses of all tasks in + // the current DAG and update the DAG's status accordingly. + dag, err := l.metadataClient.GetDAG(ctx, execution.GetExecution().CustomProperties["parent_dag_id"].GetIntValue()) + if err != nil { + glog.Errorf("DAG Status Update: failed to get DAG: %s", err.Error()) + } + pipeline, _ := l.metadataClient.GetPipelineFromExecution(ctx, execution.GetID()) + err = l.metadataClient.UpdateDAGExecutionsState(ctx, dag, pipeline) + if err != nil { + glog.Errorf("failed to update DAG state: %s", err.Error()) + } }() executedStartedTime := time.Now().Unix() execution, err = l.prePublish(ctx) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 12d7a377182..960ec6148e8 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -17,10 +17,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "strconv" "time" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" + "github.com/golang/glog" "github.com/golang/protobuf/ptypes/timestamp" "github.com/google/uuid" @@ -125,6 +126,8 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio err = fmt.Errorf("driver.RootDAG(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("RootDAG opts: ", string(b)) err = validateRootDAG(opts) if err != nil { return nil, err @@ -230,6 +233,8 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl err = fmt.Errorf("driver.Container(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("Container opts: ", string(b)) err = validateContainer(opts) if err != nil { return nil, err @@ -339,7 +344,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, err } if opts.KubernetesExecutorConfig != nil { - dagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) + dagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline, true) if err != nil { return execution, err } @@ -699,6 +704,8 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err) } }() + b, _ := json.Marshal(opts) + glog.V(4).Info("DAG opts: ", string(b)) err = validateDAG(opts) if err != nil { return nil, err @@ -749,6 +756,15 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.ParentDagID = dag.Execution.GetID() ecfg.IterationIndex = iterationIndex ecfg.NotTriggered = !execution.WillTrigger() + + // Handle writing output parameters to MLMD. + ecfg.OutputParameters = opts.Component.GetDag().GetOutputs().GetParameters() + glog.V(4).Info("outputParameters: ", ecfg.OutputParameters) + + // Handle writing output artifacts to MLMD. 
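The launcher hunk above only runs after the task's own publish step; the parent DAG's state is then re-evaluated on a best-effort basis. A minimal sketch of that ordering, using hypothetical stand-ins (publishThenRollUp, metadataAPI, dagRef are invented here) rather than the real metadata.Client API:

    package sketch

    import "log"

    // Hypothetical stand-ins for the MLMD-backed types the launcher uses.
    type dagRef struct{ id int64 }

    type metadataAPI interface {
        GetDAG(parentDagID int64) (*dagRef, error)
        UpdateDAGExecutionsState(dag *dagRef) error
    }

    // publishThenRollUp mirrors the deferred block added to Execute: publish the
    // task's results first, then refresh the parent DAG's state. Bookkeeping
    // failures are only logged so they cannot fail an already-finished task.
    func publishThenRollUp(m metadataAPI, parentDagID int64, publish func() error) error {
        defer func() {
            dag, err := m.GetDAG(parentDagID)
            if err != nil {
                log.Printf("DAG Status Update: failed to get DAG: %v", err)
                return
            }
            if err := m.UpdateDAGExecutionsState(dag); err != nil {
                log.Printf("failed to update DAG state: %v", err)
            }
        }()
        return publish()
    }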
+ ecfg.OutputArtifacts = opts.Component.GetDag().GetOutputs().GetArtifacts() + glog.V(4).Info("outputArtifacts: ", ecfg.OutputArtifacts) + if opts.Task.GetArtifactIterator() != nil { return execution, fmt.Errorf("ArtifactIterator is not implemented") } @@ -793,6 +809,12 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E ecfg.IterationCount = &count execution.IterationCount = &count } + + glog.V(4).Info("pipeline: ", pipeline) + b, _ = json.Marshal(*ecfg) + glog.V(4).Info("ecfg: ", string(b)) + glog.V(4).Infof("dag: %v", dag) + // TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started. createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg) if err != nil { @@ -939,6 +961,8 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, err = fmt.Errorf("failed to resolve inputs: %w", err) } }() + glog.V(4).Infof("dag: %v", dag) + glog.V(4).Infof("task: %v", task) inputParams, _, err := dag.Execution.GetParameters() if err != nil { return nil, err @@ -1102,20 +1126,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } return inputs, nil } - // get executions in context on demand - var tasksCache map[string]*metadata.Execution - getDAGTasks := func() (map[string]*metadata.Execution, error) { - if tasksCache != nil { - return tasksCache, nil - } - tasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline) - if err != nil { - return nil, err - } - tasksCache = tasks - return tasks, nil - } + + // Handle parameters. for name, paramSpec := range task.GetInputs().GetParameters() { + glog.V(4).Infof("name: %v", name) + glog.V(4).Infof("paramSpec: %v", paramSpec) paramError := func(err error) error { return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err) } @@ -1131,31 +1146,22 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, } inputs.ParameterValues[name] = v + // This is the case where the input comes from the output of an upstream task. 
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: - taskOutput := paramSpec.GetTaskOutputParameter() - if taskOutput.GetProducerTask() == "" { - return nil, paramError(fmt.Errorf("producer task is empty")) - } - if taskOutput.GetOutputParameterKey() == "" { - return nil, paramError(fmt.Errorf("output parameter key is empty")) - } - tasks, err := getDAGTasks() - if err != nil { - return nil, paramError(err) - } - producer, ok := tasks[taskOutput.GetProducerTask()] - if !ok { - return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask())) + cfg := resolveUpstreamOutputsConfig{ + ctx: ctx, + paramSpec: paramSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + inputs: inputs, + name: name, + err: paramError, } - _, outputs, err := producer.GetParameters() - if err != nil { - return nil, paramError(fmt.Errorf("get producer output parameters: %w", err)) - } - param, ok := outputs[taskOutput.GetOutputParameterKey()] - if !ok { - return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) + if err := resolveUpstreamParameters(cfg); err != nil { + return nil, err } - inputs.ParameterValues[name] = param + case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: runtimeValue := paramSpec.GetRuntimeValue() switch t := runtimeValue.Value.(type) { @@ -1171,7 +1177,11 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t)) } } + + // Handle artifacts. for name, artifactSpec := range task.GetInputs().GetArtifacts() { + glog.V(4).Infof("inputs: %#v", task.GetInputs()) + glog.V(4).Infof("artifacts: %#v", task.GetInputs().GetArtifacts()) artifactError := func(err error) error { return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err) } @@ -1188,43 +1198,300 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, inputs.Artifacts[name] = v case *pipelinespec.TaskInputsSpec_InputArtifactSpec_TaskOutputArtifact: - taskOutput := artifactSpec.GetTaskOutputArtifact() - if taskOutput.GetProducerTask() == "" { - return nil, artifactError(fmt.Errorf("producer task is empty")) + cfg := resolveUpstreamOutputsConfig{ + ctx: ctx, + artifactSpec: artifactSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + inputs: inputs, + name: name, + err: artifactError, } - if taskOutput.GetOutputArtifactKey() == "" { - return nil, artifactError(fmt.Errorf("output artifact key is empty")) + if err := resolveUpstreamArtifacts(cfg); err != nil { + return nil, err } - tasks, err := getDAGTasks() + default: + return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t)) + } + } + // TODO(Bobgy): validate executor inputs match component inputs definition + return inputs, nil +} + +// getDAGTasks is a recursive function that returns a map of all tasks across all DAGs in the context of nested DAGs. 
+func getDAGTasks(
+    ctx context.Context,
+    dag *metadata.DAG,
+    pipeline *metadata.Pipeline,
+    mlmd *metadata.Client,
+    flattenedTasks map[string]*metadata.Execution,
+) (map[string]*metadata.Execution, error) {
+    if flattenedTasks == nil {
+        flattenedTasks = make(map[string]*metadata.Execution)
+    }
+    currentExecutionTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline, true)
+    if err != nil {
+        return nil, err
+    }
+    for k, v := range currentExecutionTasks {
+        flattenedTasks[k] = v
+    }
+    for _, v := range currentExecutionTasks {
+
+        if v.GetExecution().GetType() == "system.DAGExecution" {
+            // Iteration count is only applied when using ParallelFor, and in
+            // that scenario you're guaranteed to have redundant task names even
+            // within a single DAG, which results in an error when
+            // mlmd.GetExecutionsInDAG is called. ParallelFor outputs should be
+            // handled with dsl.Collected.
+            _, ok := v.GetExecution().GetCustomProperties()["iteration_count"]
+            if ok {
+                glog.Infof("Found a ParallelFor task, %v. Skipping it.", v.TaskName())
+                continue
+            }
+            glog.V(4).Infof("Found a task, %v, with an execution type of system.DAGExecution. Adding its tasks to the task list.", v.TaskName())
+            subDAG, err := mlmd.GetDAG(ctx, v.GetExecution().GetId())
+            if err != nil {
+                return nil, err
+            }
+            // Pass the subDAG into a recursive call to getDAGTasks and update
+            // tasks to include the subDAG's tasks.
+            flattenedTasks, err = getDAGTasks(ctx, subDAG, pipeline, mlmd, flattenedTasks)
             if err != nil {
-                return nil, artifactError(err)
+                return nil, err
             }
-            producer, ok := tasks[taskOutput.GetProducerTask()]
+        }
+    }
+
+    return flattenedTasks, nil
+}
+
+// resolveUpstreamOutputsConfig is just a config struct used to store the input
+// parameters of the resolveUpstreamParameters and resolveUpstreamArtifacts
+// functions.
+type resolveUpstreamOutputsConfig struct {
+    ctx          context.Context
+    paramSpec    *pipelinespec.TaskInputsSpec_InputParameterSpec
+    artifactSpec *pipelinespec.TaskInputsSpec_InputArtifactSpec
+    dag          *metadata.DAG
+    pipeline     *metadata.Pipeline
+    mlmd         *metadata.Client
+    inputs       *pipelinespec.ExecutorInput_Inputs
+    name         string
+    err          func(error) error
+}
+
+// resolveUpstreamParameters resolves input parameters that come from upstream
+// tasks. These tasks can be components/containers, which is relatively
+// straightforward, or DAGs, in which case, we need to traverse the graph until
+// we arrive at a component/container (since there can be n nested DAGs).
+func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error {
+    taskOutput := cfg.paramSpec.GetTaskOutputParameter()
+    glog.V(4).Info("taskOutput: ", taskOutput)
+    producerTaskName := taskOutput.GetProducerTask()
+    if producerTaskName == "" {
+        return cfg.err(fmt.Errorf("producerTaskName is empty"))
+    }
+    outputParameterKey := taskOutput.GetOutputParameterKey()
+    if outputParameterKey == "" {
+        return cfg.err(fmt.Errorf("output parameter key is empty"))
+    }
+
+    // Get a list of tasks for the current DAG first.
+    // We use getDAGTasks instead of mlmd.GetExecutionsInDAG because the latter
+    // does not handle task-name collisions in its map, which leads to a number
+    // of unhandled edge cases and test failures.
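As a companion to getDAGTasks above: the recursion flattens every nested DAG into one map keyed by task name. A rough, self-contained sketch of that shape (the task type and flatten function are hypothetical stand-ins for MLMD executions; error handling and the ParallelFor skip are omitted):

    package main

    import "fmt"

    // Simplified stand-in: a task is either a container execution or a nested
    // DAG that owns more tasks. The real code walks MLMD executions instead.
    type task struct {
        name     string
        isDAG    bool
        children []task
    }

    // flatten mirrors the shape of getDAGTasks: collect the current level, then
    // recurse into every child DAG, accumulating one map keyed by task name.
    func flatten(tasks []task, acc map[string]task) map[string]task {
        if acc == nil {
            acc = map[string]task{}
        }
        for _, t := range tasks {
            acc[t.name] = t
            if t.isDAG {
                flatten(t.children, acc)
            }
        }
        return acc
    }

    func main() {
        root := []task{
            {name: "mantle", isDAG: true, children: []task{
                {name: "core", isDAG: true, children: []task{{name: "core-comp"}}},
            }},
            {name: "crust-comp"},
        }
        for name := range flatten(root, nil) {
            fmt.Println(name)
        }
    }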
+ tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) + if err != nil { + return cfg.err(err) + } + + producer, ok := tasks[producerTaskName] + if !ok { + return cfg.err(fmt.Errorf("producer task, %v, not in tasks", producerTaskName)) + } + glog.V(4).Info("producer: ", producer) + glog.V(4).Infof("tasks: %#v", tasks) + currentTask := producer + currentSubTaskMaybeDAG := true + // Continue looping until we reach a sub-task that is NOT a DAG. + for currentSubTaskMaybeDAG { + glog.V(4).Info("currentTask: ", currentTask.TaskName()) + // If the current task is a DAG: + if *currentTask.GetExecution().Type == "system.DAGExecution" { + // Since currentTask is a DAG, we need to deserialize its + // output parameter map so that we can look up its + // corresponding producer sub-task, reassign currentTask, + // and iterate through this loop again. + outputParametersCustomProperty, ok := currentTask.GetExecution().GetCustomProperties()["parameter_producer_task"] if !ok { - return nil, artifactError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask())) + return cfg.err(fmt.Errorf("task, %v, does not have a parameter_producer_task custom property", currentTask.TaskName())) } - // TODO(Bobgy): cache results - outputs, err := mlmd.GetOutputArtifactsByExecutionId(ctx, producer.GetID()) + glog.V(4).Infof("outputParametersCustomProperty: %#v", outputParametersCustomProperty) + + dagOutputParametersMap := make(map[string]*pipelinespec.DagOutputsSpec_DagOutputParameterSpec) + glog.V(4).Infof("outputParametersCustomProperty: %v", outputParametersCustomProperty.GetStructValue()) + + for name, value := range outputParametersCustomProperty.GetStructValue().GetFields() { + outputSpec := &pipelinespec.DagOutputsSpec_DagOutputParameterSpec{} + err := protojson.Unmarshal([]byte(value.GetStringValue()), outputSpec) + if err != nil { + return err + } + dagOutputParametersMap[name] = outputSpec + } + + glog.V(4).Infof("Deserialized dagOutputParametersMap: %v", dagOutputParametersMap) + + // Support for the 2 DagOutputParameterSpec types: + // ValueFromParameter & ValueFromOneof + var subTaskName string + switch dagOutputParametersMap[outputParameterKey].Kind.(type) { + case *pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromParameter: + subTaskName = dagOutputParametersMap[outputParameterKey].GetValueFromParameter().GetProducerSubtask() + outputParameterKey = dagOutputParametersMap[outputParameterKey].GetValueFromParameter().GetOutputParameterKey() + case *pipelinespec.DagOutputsSpec_DagOutputParameterSpec_ValueFromOneof: + // When OneOf is specified in a pipeline, the output of only 1 task is consumed even though there may be more than 1 task output set. In this case we will attempt to grab the first successful task output. + paramSelectors := dagOutputParametersMap[outputParameterKey].GetValueFromOneof().GetParameterSelectors() + glog.V(4).Infof("paramSelectors: %v", paramSelectors) + // Since we have the tasks map, we can iterate through the parameterSelectors if the ProducerSubTask is not present in the task map and then assign the new OutputParameterKey only if it exists. 
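The comment above describes the dsl.OneOf case: several candidate producer sub-tasks may be registered, but only the first one that actually finished (COMPLETE or CACHED) should be followed. A small sketch of that selection rule, with hypothetical selector and state types standing in for the parameter selectors and the flattened task map:

    package main

    import "fmt"

    // Hypothetical stand-in for a parameter selector.
    type selector struct {
        ProducerSubtask    string
        OutputParameterKey string
    }

    // pickOneOf mirrors the selection rule used below: take the first selector
    // whose producer sub-task finished successfully; otherwise report an error.
    func pickOneOf(selectors []selector, states map[string]string) (selector, error) {
        for _, s := range selectors {
            switch states[s.ProducerSubtask] {
            case "COMPLETE", "CACHED":
                return s, nil
            }
        }
        return selector{}, fmt.Errorf("processing OneOf: no successful task found")
    }

    func main() {
        selectors := []selector{
            {ProducerSubtask: "core-comp", OutputParameterKey: "Output"},
            {ProducerSubtask: "core-output-comp", OutputParameterKey: "output_key"},
        }
        states := map[string]string{"core-comp": "CANCELED", "core-output-comp": "COMPLETE"}
        s, err := pickOneOf(selectors, states)
        fmt.Println(s, err)
    }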
+                successfulOneOfTask := false
+                for !successfulOneOfTask {
+                    for _, paramSelector := range paramSelectors {
+                        subTaskName = paramSelector.GetProducerSubtask()
+                        glog.V(4).Infof("subTaskName from paramSelector: %v", subTaskName)
+                        glog.V(4).Infof("outputParameterKey from paramSelector: %v", paramSelector.GetOutputParameterKey())
+                        if subTask, ok := tasks[subTaskName]; ok {
+                            subTaskState := subTask.GetExecution().LastKnownState.String()
+                            glog.V(4).Infof("subTask: %v, subTaskState: %v", subTaskName, subTaskState)
+                            if subTaskState == "CACHED" || subTaskState == "COMPLETE" {
+
+                                outputParameterKey = paramSelector.GetOutputParameterKey()
+                                successfulOneOfTask = true
+                                break
+                            }
+                        }
+                    }
+                    if !successfulOneOfTask {
+                        return cfg.err(fmt.Errorf("processing OneOf: no successful task found"))
+                    }
+                }
+            }
+            glog.V(4).Infof("SubTaskName from outputParams: %v", subTaskName)
+            glog.V(4).Infof("OutputParameterKey from outputParams: %v", outputParameterKey)
+            if subTaskName == "" {
+                return cfg.err(fmt.Errorf("producer_subtask not in outputParams"))
+            }
+            glog.V(4).Infof(
+                "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.",
+                currentTask.TaskName(),
+                subTaskName,
+            )
+            currentTask, ok = tasks[subTaskName]
+            if !ok {
+                return cfg.err(fmt.Errorf("subTaskName, %v, not in tasks", subTaskName))
+            }
+
+        } else {
+            _, outputParametersCustomProperty, err := currentTask.GetParameters()
             if err != nil {
-                return nil, artifactError(err)
+                return err
             }
-            artifact, ok := outputs[taskOutput.GetOutputArtifactKey()]
+            cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey]
+            // Exit the loop.
+            currentSubTaskMaybeDAG = false
+        }
+    }
+
+    return nil
+}
+
+// resolveUpstreamArtifacts resolves input artifacts that come from upstream
+// tasks. These tasks can be components/containers, which is relatively
+// straightforward, or DAGs, in which case, we need to traverse the graph until
+// we arrive at a component/container (since there can be n nested DAGs).
+func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) error {
+    glog.V(4).Infof("artifactSpec: %#v", cfg.artifactSpec)
+    taskOutput := cfg.artifactSpec.GetTaskOutputArtifact()
+    if taskOutput.GetProducerTask() == "" {
+        return cfg.err(fmt.Errorf("producer task is empty"))
+    }
+    if taskOutput.GetOutputArtifactKey() == "" {
+        return cfg.err(fmt.Errorf("output artifact key is empty"))
+    }
+    tasks, err := cfg.mlmd.GetExecutionsInDAG(cfg.ctx, cfg.dag, cfg.pipeline, false)
+    if err != nil {
+        return cfg.err(err)
+    }
+
+    producer, ok := tasks[taskOutput.GetProducerTask()]
+    if !ok {
+        return cfg.err(
+            fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()),
+        )
+    }
+    glog.V(4).Info("producer: ", producer)
+    currentTask := producer
+    outputArtifactKey := taskOutput.GetOutputArtifactKey()
+    currentSubTaskMaybeDAG := true
+    // Continue looping until we reach a sub-task that is NOT a DAG.
+    for currentSubTaskMaybeDAG {
+        glog.V(4).Info("currentTask: ", currentTask.TaskName())
+        // If the current task is a DAG:
+        if *currentTask.GetExecution().Type == "system.DAGExecution" {
+            // Get the sub-task.
+            outputArtifactsCustomProperty := currentTask.GetExecution().GetCustomProperties()["artifact_producer_task"]
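For the artifact side, the DAG execution's artifact_producer_task custom property holds the output wiring as a JSON string, which is deserialized below before following the selector to the producer sub-task. A simplified sketch of that read path, assuming a trimmed-down selector shape in place of the real DagOutputsSpec_DagOutputArtifactSpec message:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Hypothetical, trimmed-down stand-in for an artifact selector: which
    // sub-task produced the artifact and under which output key.
    type artifactSelector struct {
        ProducerSubtask   string `json:"producerSubtask"`
        OutputArtifactKey string `json:"outputArtifactKey"`
    }

    func main() {
        // Roughly what the custom property could carry for a DAG output named
        // "dataset" that is really produced by the sub-task "core".
        raw := `{"dataset":{"artifactSelectors":[{"producerSubtask":"core","outputArtifactKey":"dataset"}]}}`

        var wiring map[string]struct {
            ArtifactSelectors []artifactSelector `json:"artifactSelectors"`
        }
        if err := json.Unmarshal([]byte(raw), &wiring); err != nil {
            panic(err)
        }
        for _, sel := range wiring["dataset"].ArtifactSelectors {
            fmt.Printf("next hop: task %q, key %q\n", sel.ProducerSubtask, sel.OutputArtifactKey)
        }
    }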
+            // Deserialize the output artifacts.
+            var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec
+            err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts)
+            if err != nil {
+                return err
+            }
+            glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts)
+            // Adding support for multiple output artifacts
+            var subTaskName string
+            artifactSelectors := outputArtifacts[outputArtifactKey].GetArtifactSelectors()
+
+            for _, v := range artifactSelectors {
+                glog.V(4).Infof("v: %v", v)
+                glog.V(4).Infof("v.ProducerSubtask: %v", v.ProducerSubtask)
+                glog.V(4).Infof("v.OutputArtifactKey: %v", v.OutputArtifactKey)
+                subTaskName = v.ProducerSubtask
+                outputArtifactKey = v.OutputArtifactKey
+            }
+            // If the sub-task is a DAG, reassign currentTask and run
+            // through the loop again.
+            currentTask = tasks[subTaskName]
+        } else {
+            // Base case, currentTask is a container, not a DAG.
+            outputs, err := cfg.mlmd.GetOutputArtifactsByExecutionId(cfg.ctx, currentTask.GetID())
+            if err != nil {
+                return cfg.err(err)
+            }
+            glog.V(4).Infof("outputs: %#v", outputs)
+            artifact, ok := outputs[outputArtifactKey]
+            if !ok {
+                return cfg.err(
+                    fmt.Errorf(
+                        "cannot find output artifact key %q in producer task %q",
+                        taskOutput.GetOutputArtifactKey(),
+                        taskOutput.GetProducerTask(),
+                    ),
+                )
+            }
+            runtimeArtifact, err := artifact.ToRuntimeArtifact()
+            if err != nil {
-                return nil, artifactError(err)
+                return cfg.err(err)
             }
-            inputs.Artifacts[name] = &pipelinespec.ArtifactList{
+            cfg.inputs.Artifacts[cfg.name] = &pipelinespec.ArtifactList{
                 Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact},
             }
-        default:
-            return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t))
+            // Since we are in the base case, escape the loop.
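Both resolvers above share the same traversal idea: start at the named producer, hop from DAG to producer sub-task until a container-level execution is reached, and only then read the concrete output. A condensed sketch of that walk with hypothetical in-memory types (the per-hop rewriting of the output key is omitted for brevity, and the example URI is invented):

    package main

    import "fmt"

    // Hypothetical model: a DAG points at its producer sub-task, a container
    // holds concrete outputs.
    type exec struct {
        isDAG   bool
        nextHop string            // producer sub-task, when this is a DAG
        outputs map[string]string // concrete outputs, when this is a container
    }

    func resolve(tasks map[string]exec, producer, key string) (string, error) {
        current, ok := tasks[producer]
        if !ok {
            return "", fmt.Errorf("producer task %q not found", producer)
        }
        for current.isDAG {
            next, ok := tasks[current.nextHop]
            if !ok {
                return "", fmt.Errorf("sub-task %q not found", current.nextHop)
            }
            current = next
        }
        v, ok := current.outputs[key]
        if !ok {
            return "", fmt.Errorf("output %q not found", key)
        }
        return v, nil
    }

    func main() {
        tasks := map[string]exec{
            "mantle":    {isDAG: true, nextHop: "core"},
            "core":      {isDAG: true, nextHop: "core-comp"},
            "core-comp": {outputs: map[string]string{"dataset": "s3://example-bucket/dataset"}},
        }
        fmt.Println(resolve(tasks, "mantle", "dataset"))
    }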
+ currentSubTaskMaybeDAG = false } } - // TODO(Bobgy): validate executor inputs match component inputs definition - return inputs, nil + + return nil } func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.ComponentOutputsSpec, outputUriSalt string) *pipelinespec.ExecutorInput_Outputs { diff --git a/backend/src/v2/metadata/client.go b/backend/src/v2/metadata/client.go index a292c1fe643..b261c2e1ee2 100644 --- a/backend/src/v2/metadata/client.go +++ b/backend/src/v2/metadata/client.go @@ -21,14 +21,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/kubeflow/pipelines/backend/src/common/util" - "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "path" "strconv" "strings" "sync" "time" + "github.com/kubeflow/pipelines/backend/src/common/util" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" + "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" "github.com/golang/glog" @@ -88,7 +89,9 @@ type ClientInterface interface { GetExecutions(ctx context.Context, ids []int64) ([]*pb.Execution, error) GetExecution(ctx context.Context, id int64) (*Execution, error) GetPipelineFromExecution(ctx context.Context, id int64) (*Pipeline, error) - GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error) + GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline, filter bool) (executionsMap map[string]*Execution, err error) + UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipeline *Pipeline) (err error) + PutDAGExecutionState(ctx context.Context, executionID int64, state pb.Execution_State) (err error) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) GetArtifactName(ctx context.Context, artifactId int64) (string, error) GetArtifacts(ctx context.Context, ids []int64) ([]*pb.Artifact, error) @@ -134,6 +137,8 @@ type ExecutionConfig struct { NotTriggered bool // optional, not triggered executions will have CANCELED state. ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value + OutputParameters map[string]*pipelinespec.DagOutputsSpec_DagOutputParameterSpec + OutputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec InputArtifactIDs map[string][]int64 IterationIndex *int // Index of the iteration. @@ -448,6 +453,8 @@ func getArtifactName(eventPath *pb.Event_Path) (string, error) { func (c *Client) PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error { e := execution.execution e.LastKnownState = state.Enum() + glog.V(4).Infof("outputParameters: %v", outputParameters) + glog.V(4).Infof("outputArtifacts: %v", outputArtifacts) if outputParameters != nil { // Record output parameters. @@ -500,22 +507,25 @@ func (c *Client) PublishExecution(ctx context.Context, execution *Execution, out // metadata keys const ( - keyDisplayName = "display_name" - keyTaskName = "task_name" - keyImage = "image" - keyPodName = "pod_name" - keyPodUID = "pod_uid" - keyNamespace = "namespace" - keyResourceName = "resource_name" - keyPipelineRoot = "pipeline_root" - keyStoreSessionInfo = "store_session_info" - keyCacheFingerPrint = "cache_fingerprint" - keyCachedExecutionID = "cached_execution_id" - keyInputs = "inputs" - keyOutputs = "outputs" - keyParentDagID = "parent_dag_id" // Parent DAG Execution ID. 
-    keyIterationIndex    = "iteration_index"
-    keyIterationCount    = "iteration_count"
+    keyDisplayName           = "display_name"
+    keyTaskName              = "task_name"
+    keyImage                 = "image"
+    keyPodName               = "pod_name"
+    keyPodUID                = "pod_uid"
+    keyNamespace             = "namespace"
+    keyResourceName          = "resource_name"
+    keyPipelineRoot          = "pipeline_root"
+    keyStoreSessionInfo      = "store_session_info"
+    keyCacheFingerPrint      = "cache_fingerprint"
+    keyCachedExecutionID     = "cached_execution_id"
+    keyInputs                = "inputs"
+    keyOutputs               = "outputs"
+    keyParameterProducerTask = "parameter_producer_task"
+    keyOutputArtifacts       = "output_artifacts"
+    keyArtifactProducerTask  = "artifact_producer_task"
+    keyParentDagID           = "parent_dag_id" // Parent DAG Execution ID.
+    keyIterationIndex        = "iteration_index"
+    keyIterationCount        = "iteration_count"
 )
 
 // CreateExecution creates a new MLMD execution under the specified Pipeline.
@@ -576,6 +586,40 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config
             },
         }}
     }
+    // We save the output parameter and output artifact relationships in MLMD in
+    // case they're provided by a sub-task so that we can follow the
+    // relationships and retrieve outputs downstream in components that depend
+    // on said outputs as inputs.
+    if config.OutputParameters != nil {
+        // Convert OutputParameters to a format that can be saved in MLMD.
+        glog.V(4).Info("outputParameters: ", config.OutputParameters)
+        outputParametersCustomPropertyProtoMap := make(map[string]*structpb.Value)
+
+        for name, value := range config.OutputParameters {
+            if outputParameterProtoMsg, ok := interface{}(value).(proto.Message); ok {
+                glog.V(4).Infof("name: %v, value: %v", name, value)
+                glog.V(4).Info("protoMessage: ", outputParameterProtoMsg)
+                b, err := protojson.Marshal(outputParameterProtoMsg)
+                if err != nil {
+                    return nil, err
+                }
+                outputValue, _ := structpb.NewValue(string(b))
+                outputParametersCustomPropertyProtoMap[name] = outputValue
+            }
+        }
+        e.CustomProperties[keyParameterProducerTask] = &pb.Value{Value: &pb.Value_StructValue{
+            StructValue: &structpb.Struct{
+                Fields: outputParametersCustomPropertyProtoMap,
+            },
+        }}
+    }
+    if config.OutputArtifacts != nil {
+        b, err := json.Marshal(config.OutputArtifacts)
+        if err != nil {
+            return nil, err
+        }
+        e.CustomProperties[keyArtifactProducerTask] = StringValue(string(b))
+    }
 
     req := &pb.PutExecutionRequest{
         Execution: e,
@@ -640,6 +684,61 @@ func (c *Client) PrePublishExecution(ctx context.Context, execution *Execution,
     return execution, nil
 }
 
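The CreateExecution hunk above is the write side of this bookkeeping: the DAG's output wiring is stored on the DAG execution as custom properties (protojson inside a Struct value for parameters, a plain JSON string for artifacts) so the resolvers can read it back later without re-parsing the pipeline spec. A stdlib-only sketch of that round trip, with a plain string map standing in for MLMD custom properties and a simplified spec shape:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Hypothetical, simplified stand-in for the DAG output parameter spec: which
    // sub-task and output key a DAG-level output really comes from.
    type valueFromParameter struct {
        ProducerSubtask    string `json:"producerSubtask"`
        OutputParameterKey string `json:"outputParameterKey"`
    }

    func main() {
        customProps := map[string]string{} // stand-in for execution custom properties

        // Write side (driver/CreateExecution): serialize the wiring per output name.
        wiring := map[string]valueFromParameter{
            "Output": {ProducerSubtask: "core", OutputParameterKey: "Output"},
        }
        b, err := json.Marshal(wiring)
        if err != nil {
            panic(err)
        }
        customProps["parameter_producer_task"] = string(b)

        // Read side (resolveUpstreamParameters): deserialize and follow the hop.
        var restored map[string]valueFromParameter
        if err := json.Unmarshal([]byte(customProps["parameter_producer_task"]), &restored); err != nil {
            panic(err)
        }
        fmt.Printf("next hop: task %q, key %q\n",
            restored["Output"].ProducerSubtask, restored["Output"].OutputParameterKey)
    }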
+// UpdateDAGExecutionsState checks the status of every task in the given DAG
+// and, if necessary, rolls the result up into the DAG's own execution state.
+func (c *Client) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipeline *Pipeline) error {
+    tasks, err := c.GetExecutionsInDAG(ctx, dag, pipeline, true)
+    if err != nil {
+        return err
+    }
+    glog.V(4).Infof("tasks: %v", tasks)
+    glog.V(4).Infof("Checking Tasks' State")
+    completedTasks := 0
+    failedTasks := 0
+    totalTasks := len(tasks)
+    for _, task := range tasks {
+        taskState := task.GetExecution().LastKnownState.String()
+        glog.V(4).Infof("task: %s", task.TaskName())
+        glog.V(4).Infof("task state: %s", taskState)
+        switch taskState {
+        case "FAILED":
+            failedTasks++
+        case "COMPLETE", "CACHED", "CANCELED":
+            completedTasks++
+        }
+    }
+    glog.V(4).Infof("completedTasks: %d", completedTasks)
+    glog.V(4).Infof("failedTasks: %d", failedTasks)
+    glog.V(4).Infof("totalTasks: %d", totalTasks)
+
+    glog.Infof("Attempting to update DAG state")
+    if completedTasks == totalTasks {
+        return c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_COMPLETE)
+    } else if failedTasks > 0 {
+        return c.PutDAGExecutionState(ctx, dag.Execution.GetID(), pb.Execution_FAILED)
+    }
+    glog.V(4).Infof("DAG is still running")
+    return nil
+}
+
+// PutDAGExecutionState updates the DAG execution identified by executionID to
+// the provided state.
+func (c *Client) PutDAGExecutionState(ctx context.Context, executionID int64, state pb.Execution_State) error {
+    e, err := c.GetExecution(ctx, executionID)
+    if err != nil {
+        return err
+    }
+    e.execution.LastKnownState = state.Enum()
+    _, err = c.svc.PutExecution(ctx, &pb.PutExecutionRequest{
+        Execution: e.execution,
+    })
+    return err
+}
+
 // GetExecutions ...
 func (c *Client) GetExecutions(ctx context.Context, ids []int64) ([]*pb.Execution, error) {
     req := &pb.GetExecutionsByIDRequest{ExecutionIds: ids}
@@ -704,7 +803,7 @@ func (c *Client) GetPipelineFromExecution(ctx context.Context, id int64) (*Pipel
 
 // GetExecutionsInDAG gets all executions in the DAG, and organize them
 // into a map, keyed by task name.
-func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error) {
+func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline, filter bool) (executionsMap map[string]*Execution, err error) {
     defer func() {
         if err != nil {
             err = fmt.Errorf("failed to get executions in %s: %w", dag.Info(), err)
@@ -713,7 +812,12 @@ func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pip
     executionsMap = make(map[string]*Execution)
     // Documentation on query syntax:
     // https://github.com/google/ml-metadata/blob/839c3501a195d340d2855b6ffdb2c4b0b49862c9/ml_metadata/proto/metadata_store.proto#L831
-    parentDAGFilter := fmt.Sprintf("custom_properties.parent_dag_id.int_value = %v", dag.Execution.GetID())
+    // If filter is true, the MLMD call only fetches executions of the current
+    // DAG; otherwise it fetches every execution in the context, including sub-DAGs.
+    parentDAGFilter := ""
+    if filter {
+        parentDAGFilter = fmt.Sprintf("custom_properties.parent_dag_id.int_value = %v", dag.Execution.GetID())
+    }
+
     // Note, because MLMD does not have index on custom properties right now, we
     // take a pipeline run context to limit the number of executions the DB needs to
     // iterate through to find sub-executions.
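UpdateDAGExecutionsState above reduces the child task states to a single DAG state. A compact restatement of that decision rule (plain strings stand in for the MLMD execution-state enum; "RUNNING" here just means "leave the DAG untouched"):

    package main

    import "fmt"

    // rollUp: any failed child fails the DAG; the DAG completes only once every
    // child has reached COMPLETE, CACHED, or CANCELED; otherwise it keeps running.
    func rollUp(childStates []string) string {
        completed, failed := 0, 0
        for _, s := range childStates {
            switch s {
            case "FAILED":
                failed++
            case "COMPLETE", "CACHED", "CANCELED":
                completed++
            }
        }
        switch {
        case failed > 0:
            return "FAILED"
        case completed == len(childStates):
            return "COMPLETE"
        default:
            return "RUNNING"
        }
    }

    func main() {
        fmt.Println(rollUp([]string{"COMPLETE", "CACHED"}))  // COMPLETE
        fmt.Println(rollUp([]string{"COMPLETE", "FAILED"}))  // FAILED
        fmt.Println(rollUp([]string{"COMPLETE", "RUNNING"})) // RUNNING
    }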
@@ -732,14 +836,26 @@ func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pip } execs := res.GetExecutions() + glog.V(4).Infof("execs: %v", execs) for _, e := range execs { execution := &Execution{execution: e} taskName := execution.TaskName() if taskName == "" { - return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID()) + if e.GetCustomProperties()[keyParentDagID] != nil { + return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID()) + } + // When retrieving executions without the parentDAGFilter, the rootDAG execution is supplied but does not have an associated TaskName nor is the parentDagID set, therefore we won't include it in the executionsMap. + continue } existing, ok := executionsMap[taskName] if ok { + // TODO: The failure to handle this results in a specific edge + // case which has yet to be solved for. If you have three nested + // pipelines: A, which calls B, which calls C, and B and C share + // a task that A does not have but depends on in a producer + // subtask, when GetExecutionsInDAG is called, it will raise + // this error. + // TODO(Bobgy): to support retry, we need to handle multiple tasks with the same task name. return nil, fmt.Errorf("two tasks have the same task name %q, id1=%v id2=%v", taskName, existing.GetID(), execution.GetID()) } diff --git a/backend/src/v2/metadata/client_fake.go b/backend/src/v2/metadata/client_fake.go index 8e9b7b84677..beaddcc098c 100644 --- a/backend/src/v2/metadata/client_fake.go +++ b/backend/src/v2/metadata/client_fake.go @@ -19,6 +19,7 @@ package metadata import ( "context" + "github.com/kubeflow/pipelines/backend/src/v2/objectstore" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" @@ -64,10 +65,15 @@ func (c *FakeClient) GetPipelineFromExecution(ctx context.Context, id int64) (*P return nil, nil } -func (c *FakeClient) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline) (executionsMap map[string]*Execution, err error) { +func (c *FakeClient) GetExecutionsInDAG(ctx context.Context, dag *DAG, pipeline *Pipeline, filter bool) (executionsMap map[string]*Execution, err error) { return nil, nil } - +func (c *FakeClient) UpdateDAGExecutionsState(ctx context.Context, dag *DAG, pipeline *Pipeline) (err error) { + return nil +} +func (c *FakeClient) PutDAGExecutionState(ctx context.Context, executionID int64, state pb.Execution_State) (err error) { + return nil +} func (c *FakeClient) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) { return nil, nil } diff --git a/backend/src/v2/metadata/client_test.go b/backend/src/v2/metadata/client_test.go index 94f081b32b0..3cb5e1cc64c 100644 --- a/backend/src/v2/metadata/client_test.go +++ b/backend/src/v2/metadata/client_test.go @@ -311,7 +311,7 @@ func Test_DAG(t *testing.T) { t.Fatal(err) } rootDAG := &metadata.DAG{Execution: root} - rootChildren, err := client.GetExecutionsInDAG(ctx, rootDAG, pipeline) + rootChildren, err := client.GetExecutionsInDAG(ctx, rootDAG, pipeline, true) if err != nil { t.Fatal(err) } @@ -324,7 +324,7 @@ func Test_DAG(t *testing.T) { if rootChildren["task2"].GetID() != task2.GetID() { t.Errorf("executions[\"task2\"].GetID()=%v, task2.GetID()=%v. 
Not equal", rootChildren["task2"].GetID(), task2.GetID()) } - task1Children, err := client.GetExecutionsInDAG(ctx, &metadata.DAG{Execution: task1DAG}, pipeline) + task1Children, err := client.GetExecutionsInDAG(ctx, &metadata.DAG{Execution: task1DAG}, pipeline, true) if len(task1Children) != 1 { t.Errorf("len(task1Children)=%v, expect 1", len(task1Children)) } diff --git a/backend/src/v2/objectstore/object_store.go b/backend/src/v2/objectstore/object_store.go index 41b5118c49f..42ec6418c43 100644 --- a/backend/src/v2/objectstore/object_store.go +++ b/backend/src/v2/objectstore/object_store.go @@ -17,6 +17,13 @@ package objectstore import ( "context" "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" @@ -27,14 +34,8 @@ import ( "gocloud.dev/blob/s3blob" "gocloud.dev/gcp" "golang.org/x/oauth2/google" - "io" - "io/ioutil" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" - "os" - "path/filepath" - "regexp" - "strings" ) func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config) (bucket *blob.Bucket, err error) { diff --git a/samples/v2/sample_test.py b/samples/v2/sample_test.py index d34599a3c18..2af7c4fba7d 100644 --- a/samples/v2/sample_test.py +++ b/samples/v2/sample_test.py @@ -11,20 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import unittest -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +import inspect +import os from pprint import pprint from typing import List +import unittest +import component_with_optional_inputs +import hello_world import kfp from kfp.dsl.graph_component import GraphComponent -import component_with_optional_inputs +import pipeline_container_no_input import pipeline_with_env -import hello_world import producer_consumer_param -import pipeline_container_no_input +import subdagio import two_step_pipeline_containerized _MINUTE = 60 # seconds @@ -38,16 +41,21 @@ class TestCase: class SampleTest(unittest.TestCase): - _kfp_host_and_port = os.getenv('KFP_API_HOST_AND_PORT', 'http://localhost:8888') - _kfp_ui_and_port = os.getenv('KFP_UI_HOST_AND_PORT', 'http://localhost:8080') + _kfp_host_and_port = os.getenv('KFP_API_HOST_AND_PORT', + 'http://localhost:8888') + _kfp_ui_and_port = os.getenv('KFP_UI_HOST_AND_PORT', + 'http://localhost:8080') _client = kfp.Client(host=_kfp_host_and_port, ui_host=_kfp_ui_and_port) def test(self): test_cases: List[TestCase] = [ TestCase(pipeline_func=hello_world.pipeline_hello_world), - TestCase(pipeline_func=producer_consumer_param.producer_consumer_param_pipeline), - TestCase(pipeline_func=pipeline_container_no_input.pipeline_container_no_input), - TestCase(pipeline_func=two_step_pipeline_containerized.two_step_pipeline_containerized), + TestCase(pipeline_func=producer_consumer_param + .producer_consumer_param_pipeline), + TestCase(pipeline_func=pipeline_container_no_input + .pipeline_container_no_input), + TestCase(pipeline_func=two_step_pipeline_containerized + .two_step_pipeline_containerized), TestCase(pipeline_func=component_with_optional_inputs.pipeline), TestCase(pipeline_func=pipeline_with_env.pipeline_with_env), @@ -56,27 
+64,45 @@ def test(self): # TestCase(pipeline_func=pipeline_with_volume.pipeline_with_volume), # TestCase(pipeline_func=pipeline_with_secret_as_volume.pipeline_secret_volume), # TestCase(pipeline_func=pipeline_with_secret_as_env.pipeline_secret_env), + TestCase(pipeline_func=subdagio.parameter.crust), + TestCase(pipeline_func=subdagio.parameter_cache.crust), + TestCase(pipeline_func=subdagio.mixed_parameters.crust), + TestCase( + pipeline_func=subdagio.multiple_parameters_namedtuple.crust), + TestCase(pipeline_func=subdagio.parameter_oneof.crust), + TestCase(pipeline_func=subdagio.artifact_cache.crust), + TestCase(pipeline_func=subdagio.artifact.crust), + TestCase( + pipeline_func=subdagio.multiple_artifacts_namedtuple.crust), ] with ThreadPoolExecutor() as executor: futures = [ - executor.submit(self.run_test_case, test_case.pipeline_func, test_case.timeout) - for test_case in test_cases + executor.submit(self.run_test_case, test_case.pipeline_func, + test_case.timeout) for test_case in test_cases ] for future in as_completed(futures): future.result() def run_test_case(self, pipeline_func: GraphComponent, timeout: int): with self.subTest(pipeline=pipeline_func, msg=pipeline_func.name): - run_result = self._client.create_run_from_pipeline_func(pipeline_func=pipeline_func) + print( + f'Running pipeline: {inspect.getmodule(pipeline_func.pipeline_func).__name__}/{pipeline_func.name}.' + ) + run_result = self._client.create_run_from_pipeline_func( + pipeline_func=pipeline_func) run_response = run_result.wait_for_run_completion(timeout) pprint(run_response.run_details) - print("Run details page URL:") - print(f"{self._kfp_ui_and_port}/#/runs/details/{run_response.run_id}") + print('Run details page URL:') + print( + f'{self._kfp_ui_and_port}/#/runs/details/{run_response.run_id}') - self.assertEqual(run_response.state, "SUCCEEDED") + self.assertEqual(run_response.state, 'SUCCEEDED') + print( + f'Pipeline, {inspect.getmodule(pipeline_func.pipeline_func).__name__}/{pipeline_func.name}, succeeded.' 
+ ) if __name__ == '__main__': diff --git a/samples/v2/subdagio/__init__.py b/samples/v2/subdagio/__init__.py new file mode 100644 index 00000000000..024415d6bd2 --- /dev/null +++ b/samples/v2/subdagio/__init__.py @@ -0,0 +1,8 @@ +from subdagio import artifact +from subdagio import artifact_cache +from subdagio import mixed_parameters +from subdagio import multiple_artifacts_namedtuple +from subdagio import multiple_parameters_namedtuple +from subdagio import parameter +from subdagio import parameter_cache +from subdagio import parameter_oneof diff --git a/samples/v2/subdagio/artifact.py b/samples/v2/subdagio/artifact.py new file mode 100644 index 00000000000..8f425662a1b --- /dev/null +++ b/samples/v2/subdagio/artifact.py @@ -0,0 +1,47 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(dataset: dsl.Output[dsl.Dataset]): + with open(dataset.path, 'w') as f: + f.write('foo') + + +@dsl.component +def crust_comp(input: dsl.Dataset): + with open(input.path, 'r') as f: + print('input: ', f.read()) + + +@dsl.pipeline +def core() -> dsl.Dataset: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> dsl.Dataset: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/artifact_cache.py b/samples/v2/subdagio/artifact_cache.py new file mode 100644 index 00000000000..5b52b25fb23 --- /dev/null +++ b/samples/v2/subdagio/artifact_cache.py @@ -0,0 +1,42 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(dataset: dsl.Output[dsl.Dataset]): + with open(dataset.path, 'w') as f: + f.write('foo') + + +@dsl.component +def crust_comp(input: dsl.Dataset): + with open(input.path, 'r') as f: + print('input: ', f.read()) + + +@dsl.pipeline +def core() -> dsl.Dataset: + task = core_comp() + + return task.output + + +@dsl.pipeline +def mantle() -> dsl.Dataset: + dag_task = core() + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + + task = crust_comp(input=dag_task.output) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/mixed_parameters.py b/samples/v2/subdagio/mixed_parameters.py new file mode 100644 index 00000000000..0a660d335d9 --- /dev/null +++ b/samples/v2/subdagio/mixed_parameters.py @@ -0,0 +1,48 @@ +import os + +from kfp import Client +from kfp import dsl +from kfp.compiler import Compiler + + +@dsl.component +def core_comp() -> int: + return 1 + + +@dsl.component +def crust_comp(x: int, y: int): + print('sum :', x + y) + + +@dsl.pipeline +def core() -> int: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> int: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + 
+@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(x=2, y=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + Compiler().compile( + pipeline_func=crust, + package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/multiple_artifacts_namedtuple.py b/samples/v2/subdagio/multiple_artifacts_namedtuple.py new file mode 100644 index 00000000000..7d2777d38b0 --- /dev/null +++ b/samples/v2/subdagio/multiple_artifacts_namedtuple.py @@ -0,0 +1,66 @@ +import os +from typing import NamedTuple + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp(ds1: dsl.Output[dsl.Dataset], ds2: dsl.Output[dsl.Dataset]): + with open(ds1.path, 'w') as f: + f.write('foo') + with open(ds2.path, 'w') as f: + f.write('bar') + + +@dsl.component +def crust_comp( + ds1: dsl.Dataset, + ds2: dsl.Dataset, +): + with open(ds1.path, 'r') as f: + print('ds1: ', f.read()) + with open(ds2.path, 'r') as f: + print('ds2: ', f.read()) + + +@dsl.pipeline +def core() -> NamedTuple( + 'outputs', + ds1=dsl.Dataset, + ds2=dsl.Dataset, +): # type: ignore + task = core_comp() + task.set_caching_options(False) + + return task.outputs + + +@dsl.pipeline +def mantle() -> NamedTuple( + 'outputs', + ds1=dsl.Dataset, + ds2=dsl.Dataset, +): # type: ignore + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.outputs + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp( + ds1=dag_task.outputs['ds1'], + ds2=dag_task.outputs['ds2'], + ) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/multiple_parameters_namedtuple.py b/samples/v2/subdagio/multiple_parameters_namedtuple.py new file mode 100644 index 00000000000..29699088554 --- /dev/null +++ b/samples/v2/subdagio/multiple_parameters_namedtuple.py @@ -0,0 +1,51 @@ +import os +from typing import NamedTuple + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + outputs = NamedTuple('outputs', val1=str, val2=str) + return outputs('foo', 'bar') + + +@dsl.component +def crust_comp(val1: str, val2: str): + print('val1: ', val1) + print('val2: ', val2) + + +@dsl.pipeline +def core() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + task = core_comp() + task.set_caching_options(False) + + return task.outputs + + +@dsl.pipeline +def mantle() -> NamedTuple('outputs', val1=str, val2=str): # type: ignore + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.outputs + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp( + val1=dag_task.outputs['val1'], + val2=dag_task.outputs['val2'], + ) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git 
a/samples/v2/subdagio/parameter.py b/samples/v2/subdagio/parameter.py new file mode 100644 index 00000000000..c00439dd1c8 --- /dev/null +++ b/samples/v2/subdagio/parameter.py @@ -0,0 +1,45 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> str: + return 'foo' + + +@dsl.component +def crust_comp(input: str): + print('input :', input) + + +@dsl.pipeline +def core() -> str: + task = core_comp() + task.set_caching_options(False) + + return task.output + + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/parameter_cache.py b/samples/v2/subdagio/parameter_cache.py new file mode 100644 index 00000000000..9fe2402e2b8 --- /dev/null +++ b/samples/v2/subdagio/parameter_cache.py @@ -0,0 +1,40 @@ +import os + +from kfp import Client +from kfp import dsl + + +@dsl.component +def core_comp() -> str: + return 'foo' + + +@dsl.component +def crust_comp(input: str): + print('input :', input) + + +@dsl.pipeline +def core() -> str: + task = core_comp() + + return task.output + + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + + return dag_task.output + + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + task = crust_comp(input=dag_task.output) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/samples/v2/subdagio/parameter_oneof.py b/samples/v2/subdagio/parameter_oneof.py new file mode 100644 index 00000000000..6459c155ef6 --- /dev/null +++ b/samples/v2/subdagio/parameter_oneof.py @@ -0,0 +1,54 @@ +import os + +from kfp import Client +from kfp import dsl + +@dsl.component +def flip_coin() -> str: + import random + return 'heads' if random.randint(0, 1) == 0 else 'tails' + +@dsl.component +def core_comp(input: str) -> str: + print('input :', input) + return input + +@dsl.component +def core_output_comp(input: str, output_key: dsl.OutputPath(str)): + print('input :', input) + with open(output_key, 'w') as f: + f.write(input) + +@dsl.component +def crust_comp(input: str): + print('input :', input) + +@dsl.pipeline +def core() -> str: + flip_coin_task = flip_coin().set_caching_options(False) + with dsl.If(flip_coin_task.output == 'heads'): + t1 = core_comp(input='Got heads!').set_caching_options(False) + with dsl.Else(): + t2 = core_output_comp(input='Got tails!').set_caching_options(False) + return dsl.OneOf(t1.output, t2.outputs['output_key']) + +@dsl.pipeline +def mantle() -> str: + dag_task = core() + dag_task.set_caching_options(False) + + return dag_task.output + +@dsl.pipeline(name=os.path.basename(__file__).removesuffix('.py') + '-pipeline') +def crust(): + dag_task = mantle() + dag_task.set_caching_options(False) + + task = crust_comp(input=dag_task.output) + task.set_caching_options(False) + + +if __name__ == '__main__': + # Compiler().compile(pipeline_func=crust, 
package_path=f"{__file__.removesuffix('.py')}.yaml") + client = Client() + client.create_run_from_pipeline_func(crust) diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 3f1575005da..ab638b0547a 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -1873,7 +1873,9 @@ def validate_pipeline_outputs_dict( f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output from within the control flow group dsl.{channel.task.parent_task_group.__class__.__name__}.' ) else: - raise ValueError(f'Got unknown pipeline output: {channel}.') + raise ValueError( + f'Got unknown pipeline output, {channel}, of type {type(channel)}.' + ) def create_pipeline_spec( @@ -2013,13 +2015,20 @@ def convert_pipeline_outputs_to_dict( output name to PipelineChannel.""" if pipeline_outputs is None: return {} + elif isinstance(pipeline_outputs, dict): + # This condition is required to support the case where a nested pipeline + # returns a namedtuple but its output is converted into a dict by + # earlier invocations of this function (a few lines down). + return pipeline_outputs elif isinstance(pipeline_outputs, pipeline_channel.PipelineChannel): return {component_factory.SINGLE_OUTPUT_NAME: pipeline_outputs} elif isinstance(pipeline_outputs, tuple) and hasattr( pipeline_outputs, '_asdict'): return dict(pipeline_outputs._asdict()) else: - raise ValueError(f'Got unknown pipeline output: {pipeline_outputs}') + raise ValueError( + f'Got unknown pipeline output, {pipeline_outputs}, of type {type(pipeline_outputs)}.' + ) def write_pipeline_spec_to_file(