From b9982018fc67183bd8dc49cf6b4d8f1bdc34fb25 Mon Sep 17 00:00:00 2001
From: "jerome@aqueducthq.com" <jerome@aqueducthq.com>
Date: Thu, 27 Apr 2023 00:18:34 -0700
Subject: [PATCH] Do not schedule execution of save operator if other compute
 operators fail in a DAG; add a unit test case and linter-related fixes for
 the Python llm module

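The engine change keeps save (load) operators out of both the initial
scheduling pass and the per-iteration rescheduling in aq_engine.execute, and
only enqueues them once the in-progress set has drained with every compute
operator finished. The standalone Go sketch below illustrates the deferral
loop; the types and names in it (op, opKind, depCount, inProgress) are
simplified stand-ins for this example only, not the engine's actual API:

    package main

    import "fmt"

    type opKind int

    const (
        compute opKind = iota
        load // save operators
    )

    type op struct {
        name       string
        kind       opKind
        downstream []string
    }

    func main() {
        // Toy DAG: extract -> transform -> save, where "save" is a load operator.
        ops := map[string]*op{
            "extract":   {name: "extract", kind: compute, downstream: []string{"transform"}},
            "transform": {name: "transform", kind: compute, downstream: []string{"save"}},
            "save":      {name: "save", kind: load},
        }
        depCount := map[string]int{"extract": 0, "transform": 1, "save": 1}

        inProgress := map[string]*op{}
        // First pass: schedule only compute operators with no pending inputs.
        for name, o := range ops {
            if o.kind != load && depCount[name] == 0 {
                inProgress[name] = o
            }
        }

        loadOpsScheduled := false
        for len(inProgress) > 0 {
            for name, o := range inProgress {
                fmt.Println("running", name)
                delete(inProgress, name)
                for _, next := range o.downstream {
                    depCount[next]--
                    // Reschedule only downstream compute operators; load
                    // operators stay deferred even at zero dependencies.
                    if depCount[next] == 0 && ops[next].kind != load {
                        inProgress[next] = ops[next]
                    }
                }
            }
            // All compute operators have drained from the in-progress set;
            // only now are the load operators scheduled.
            if len(inProgress) == 0 && !loadOpsScheduled {
                for name, o := range ops {
                    if o.kind == load {
                        inProgress[name] = o
                    }
                }
                loadOpsScheduled = true
            }
        }
    }

In the real engine, a failed compute operator is expected to end the run
before the final scheduling pass is reached, so the save operators never
execute and no data is persisted; the new unit test checks exactly that.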
---
 .../sdk/aqueduct_tests/flow_test.py           | 35 ++++++++++++++++++-
 scripts/run_linters.py                        |  1 +
 src/golang/lib/engine/aq_engine.go            | 34 ++++++++++++++++--
 .../utils/dolly_instruct_pipeline.py          |  1 -
 src/llm/aqueduct_llm/vicuna_7b.py             |  2 +-
 src/llm/setup.py                              |  2 +-
 6 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/integration_tests/sdk/aqueduct_tests/flow_test.py b/integration_tests/sdk/aqueduct_tests/flow_test.py
index 9c0701caa..04394f1d5 100644
--- a/integration_tests/sdk/aqueduct_tests/flow_test.py
+++ b/integration_tests/sdk/aqueduct_tests/flow_test.py
@@ -1,9 +1,10 @@
+import time
 import uuid
 from datetime import datetime, timedelta
 
 import pandas as pd
 import pytest
-from aqueduct.constants.enums import ExecutionStatus
+from aqueduct.constants.enums import ExecutionStatus, LoadUpdateMode
 from aqueduct.error import InvalidRequestError, InvalidUserArgumentException
 
 import aqueduct
@@ -616,3 +617,35 @@ def noop():
         client.delete_flow(flow_id=flow.id(), flow_name="not a real flow")
 
     client.delete_flow(flow_name=flow.name())
+
+
+def test_flow_with_failed_compute_operators(
+    client, flow_name, data_integration, engine, data_validator
+):
+    """
+    Check that when one or more compute operators fail, the save (load) operator never runs.
+    """
+
+    @op
+    def bar(arg):
+        return 5 / 0
+
+    @op
+    def baz(arg):
+        time.sleep(10)
+        return arg
+
+    table_name = generate_table_name()
+    result = data_integration.sql("select * from hotel_reviews limit 5")
+    test_data = bar.lazy(baz.lazy(result))
+    save(data_integration, result, name=table_name, update_mode=LoadUpdateMode.REPLACE)
+
+    publish_flow_test(
+        client,
+        artifacts=[test_data, result],
+        name=flow_name(),
+        engine=engine,
+        expected_statuses=[ExecutionStatus.FAILED],
+    )
+    # Since a compute operator failed, the save should never run and no data should be persisted.
+    data_validator.check_saved_artifact_data_does_not_exist(result.id())
diff --git a/scripts/run_linters.py b/scripts/run_linters.py
index b6a7aa855..532b7af0d 100644
--- a/scripts/run_linters.py
+++ b/scripts/run_linters.py
@@ -27,6 +27,7 @@ def execute_command(args, cwd=None):
 
 def lint_python(cwd):
     execute_command(["black", join(cwd, "src/python"), "--line-length=100"])
+    execute_command(["black", join(cwd, "src/llm"), "--line-length=100"])
     execute_command(["black", join(cwd, "sdk"), "--line-length=100"])
     execute_command(["black", join(cwd, "integration_tests"), "--line-length=100"])
     execute_command(["black", join(cwd, "manual_qa_tests"), "--line-length=100"])
diff --git a/src/golang/lib/engine/aq_engine.go b/src/golang/lib/engine/aq_engine.go
index 1e9b95128..c8f88b312 100644
--- a/src/golang/lib/engine/aq_engine.go
+++ b/src/golang/lib/engine/aq_engine.go
@@ -17,6 +17,7 @@ import (
 	shared_utils "github.com/aqueducthq/aqueduct/lib/lib_utils"
 	"github.com/aqueducthq/aqueduct/lib/models"
 	"github.com/aqueducthq/aqueduct/lib/models/shared"
+	operator_model "github.com/aqueducthq/aqueduct/lib/models/shared/operator"
 	"github.com/aqueducthq/aqueduct/lib/models/shared/operator/param"
 	"github.com/aqueducthq/aqueduct/lib/repos"
 	"github.com/aqueducthq/aqueduct/lib/vault"
@@ -925,6 +926,11 @@ func (eng *aqEngine) execute(
 
 	// Kick off execution by starting all operators that don't have any inputs.
 	for _, op := range dag.Operators() {
+		log.Infof("DAG operator %s (%d operators total), type: %s", op.Name(), len(dag.Operators()), op.Type())
+		if op.Type() == operator_model.LoadType {
+			log.Infof("Deferring save operator %s (type %s) until all compute operators finish.", op.Name(), op.Type())
+			continue
+		}
 		if opToDependencyCount[op.ID()] == 0 {
 			inProgressOps[op.ID()] = op
 		}
@@ -952,12 +958,17 @@ func (eng *aqEngine) execute(
 
 	start := time.Now()
 
+	// We defer save operations until all other compute operations have completed successfully.
+	// This flag tracks whether the save operations have been scheduled for execution.
+	loadOpsDone := false
+
 	for len(inProgressOps) > 0 {
 		if time.Since(start) > timeConfig.ExecTimeout {
 			return errors.Newf("Reached timeout %s waiting for workflow to complete.", timeConfig.ExecTimeout)
 		}
 
 		for _, op := range inProgressOps {
+			log.Infof("Operator %s in progress (%d in flight), type: %s", op.Name(), len(inProgressOps), op.Type())
 			if op.Dynamic() && !op.GetDynamicProperties().Prepared() {
 				err = dynamic.PrepareCluster(
 					ctx,
@@ -1079,19 +1090,26 @@ func (eng *aqEngine) execute(
 				}
 
 				for _, nextOp := range nextOps {
+
 					// Decrement the active dependency count for every downstream operator.
 					// Once this count reaches zero, we can schedule the next operator.
 					opToDependencyCount[nextOp.ID()] -= 1
 
 					if opToDependencyCount[nextOp.ID()] < 0 {
-						return errors.Newf("Internal error: operator %s has a negative dependnecy count.", op.Name())
+						return errors.Newf("Internal error: operator %s has a negative dependency count.", op.Name())
 					}
 
 					if opToDependencyCount[nextOp.ID()] == 0 {
 						// Defensive check: do not reschedule an already in-progress operator. This shouldn't actually
 						// matter because we only keep and update a single copy an on operator.
 						if _, ok := inProgressOps[nextOp.ID()]; !ok {
-							inProgressOps[nextOp.ID()] = nextOp
+							// In this pass, schedule only pending compute operators; save (load)
+							// operators are deferred until the end.
+							if nextOp.Type() != operator_model.LoadType {
+								inProgressOps[nextOp.ID()] = nextOp
+							} else {
+								log.Infof("Skip load operator %s", nextOp.Name())
+							}
 						}
 					}
 				}
@@ -1099,6 +1117,18 @@ func (eng *aqEngine) execute(
 
 			time.Sleep(timeConfig.OperatorPollInterval)
 		}
+		// There are no more compute operations to run. Run the save (load) operations to persist
+		// artifacts to the destination. The save operations are scheduled last so that data is
+		// persisted only if all preceding compute operations succeeded.
+		if len(inProgressOps) == 0 && !loadOpsDone {
+			for _, saveOp := range dag.Operators() {
+				if saveOp.Type() == operator_model.LoadType {
+					log.Infof("Scheduling load operator %s for execution", saveOp.Name())
+					inProgressOps[saveOp.ID()] = saveOp
+				}
+			}
+			loadOpsDone = true
+		}
 	}
 
 	if len(completedOps) != len(dag.Operators()) {
diff --git a/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py b/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
index 421154625..a29b412fb 100644
--- a/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
+++ b/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
@@ -5,7 +5,6 @@
 
 import numpy as np
 from transformers import Pipeline, PreTrainedTokenizer
-
 from transformers.utils import is_tf_available
 
 if is_tf_available():
diff --git a/src/llm/aqueduct_llm/vicuna_7b.py b/src/llm/aqueduct_llm/vicuna_7b.py
index 1645dc898..c5a3e993e 100644
--- a/src/llm/aqueduct_llm/vicuna_7b.py
+++ b/src/llm/aqueduct_llm/vicuna_7b.py
@@ -3,7 +3,7 @@
 
 import torch
 from fastchat.conversation import get_default_conv_template
-from fastchat.serve.inference import load_model, compute_skip_echo_len
+from fastchat.serve.inference import compute_skip_echo_len, load_model
 
 default_max_gpu_memory = "13GiB"
 default_temperature = 0.7
diff --git a/src/llm/setup.py b/src/llm/setup.py
index 12db42e72..4bb3cf452 100644
--- a/src/llm/setup.py
+++ b/src/llm/setup.py
@@ -1,4 +1,4 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = open("requirements.txt").read().strip().split("\n")