diff --git a/integration_tests/sdk/aqueduct_tests/flow_test.py b/integration_tests/sdk/aqueduct_tests/flow_test.py
index 9c0701caa..04394f1d5 100644
--- a/integration_tests/sdk/aqueduct_tests/flow_test.py
+++ b/integration_tests/sdk/aqueduct_tests/flow_test.py
@@ -1,9 +1,10 @@
+import time
 import uuid
 from datetime import datetime, timedelta
 
 import pandas as pd
 import pytest
-from aqueduct.constants.enums import ExecutionStatus
+from aqueduct.constants.enums import ExecutionStatus, LoadUpdateMode
 from aqueduct.error import InvalidRequestError, InvalidUserArgumentException
 
 import aqueduct
@@ -616,3 +617,35 @@ def noop():
         client.delete_flow(flow_id=flow.id(), flow_name="not a real flow")
 
     client.delete_flow(flow_name=flow.name())
+
+
+def test_flow_with_failed_compute_operators(
+    client, flow_name, data_integration, engine, data_validator
+):
+    """
+    Test that if one or more compute operators fail, the save/load operator does not succeed either.
+    """
+
+    @op
+    def bar(arg):
+        return 5 / 0
+
+    @op
+    def baz(arg):
+        time.sleep(10)
+        return arg
+
+    table_name = generate_table_name()
+    result = data_integration.sql("select * from hotel_reviews limit 5")
+    test_data = bar.lazy(baz.lazy(result))
+    save(data_integration, result, name=table_name, update_mode=LoadUpdateMode.REPLACE)
+
+    publish_flow_test(
+        client,
+        artifacts=[test_data, result],
+        name=flow_name(),
+        engine=engine,
+        expected_statuses=[ExecutionStatus.FAILED],
+    )
+
+    data_validator.check_saved_artifact_data_does_not_exist(result.id())
diff --git a/scripts/run_linters.py b/scripts/run_linters.py
index b6a7aa855..532b7af0d 100644
--- a/scripts/run_linters.py
+++ b/scripts/run_linters.py
@@ -27,6 +27,7 @@ def execute_command(args, cwd=None):
 
 def lint_python(cwd):
     execute_command(["black", join(cwd, "src/python"), "--line-length=100"])
+    execute_command(["black", join(cwd, "src/llm"), "--line-length=100"])
     execute_command(["black", join(cwd, "sdk"), "--line-length=100"])
     execute_command(["black", join(cwd, "integration_tests"), "--line-length=100"])
     execute_command(["black", join(cwd, "manual_qa_tests"), "--line-length=100"])
diff --git a/src/golang/lib/engine/aq_engine.go b/src/golang/lib/engine/aq_engine.go
index 1e9b95128..c8f88b312 100644
--- a/src/golang/lib/engine/aq_engine.go
+++ b/src/golang/lib/engine/aq_engine.go
@@ -17,6 +17,7 @@ import (
     shared_utils "github.com/aqueducthq/aqueduct/lib/lib_utils"
     "github.com/aqueducthq/aqueduct/lib/models"
     "github.com/aqueducthq/aqueduct/lib/models/shared"
+    operator_model "github.com/aqueducthq/aqueduct/lib/models/shared/operator"
     "github.com/aqueducthq/aqueduct/lib/models/shared/operator/param"
     "github.com/aqueducthq/aqueduct/lib/repos"
     "github.com/aqueducthq/aqueduct/lib/vault"
@@ -925,6 +926,11 @@ func (eng *aqEngine) execute(
     // Kick off execution by starting all operators that don't have any inputs.
     for _, op := range dag.Operators() {
+        log.Infof("DAG operator %s [%d], Type: %s", op.Name(), len(dag.Operators()), op.Type())
+        if op.Type() == operator_model.LoadType {
+            log.Infof("Skipping save operator %s, Type: %s", op.Name(), op.Type())
+            continue
+        }
         if opToDependencyCount[op.ID()] == 0 {
             inProgressOps[op.ID()] = op
         }
     }
@@ -952,12 +958,17 @@
     start := time.Now()
 
+    // We defer save operations until all compute operations have completed successfully.
+    // This flag tracks whether the save operations have been scheduled for execution.
+    loadOpsDone := false
+
     for len(inProgressOps) > 0 {
         if time.Since(start) > timeConfig.ExecTimeout {
             return errors.Newf("Reached timeout %s waiting for workflow to complete.", timeConfig.ExecTimeout)
         }
 
         for _, op := range inProgressOps {
+            log.Infof("Operator in progress %s [%d], Type: %s", op.Name(), len(inProgressOps), op.Type())
             if op.Dynamic() && !op.GetDynamicProperties().Prepared() {
                 err = dynamic.PrepareCluster(
                     ctx,
@@ -1079,19 +1090,26 @@
             }
 
             for _, nextOp := range nextOps {
+                // Decrement the active dependency count for every downstream operator.
                 // Once this count reaches zero, we can schedule the next operator.
                 opToDependencyCount[nextOp.ID()] -= 1
 
                 if opToDependencyCount[nextOp.ID()] < 0 {
-                    return errors.Newf("Internal error: operator %s has a negative dependnecy count.", op.Name())
+                    return errors.Newf("Internal error: operator %s has a negative dependency count.", op.Name())
                 }
 
                 if opToDependencyCount[nextOp.ID()] == 0 {
                     // Defensive check: do not reschedule an already in-progress operator. This shouldn't actually
                     // matter because we only keep and update a single copy an on operator.
                     if _, ok := inProgressOps[nextOp.ID()]; !ok {
-                        inProgressOps[nextOp.ID()] = nextOp
+                        // In this pass, only pick up pending compute operations; the save
+                        // operations are deferred to the end.
+                        if nextOp.Type() != operator_model.LoadType {
+                            inProgressOps[nextOp.ID()] = nextOp
+                        } else {
+                            log.Infof("Skipping load operator %s", nextOp.Name())
+                        }
                     }
                 }
             }
@@ -1099,6 +1117,18 @@
             time.Sleep(timeConfig.OperatorPollInterval)
         }
 
+        // There are no more compute operations to run. Run the save (load) operations to persist
+        // artifacts to the DB. The save operations are scheduled last so that data is persisted
+        // only if all preceding compute operations succeed.
+        if len(inProgressOps) == 0 && !loadOpsDone {
+            for _, saveOp := range workflowDag.Operators() {
+                if saveOp.Type() == operator_model.LoadType {
+                    log.Infof("Scheduling load operator %s for execution", saveOp.Name())
+                    inProgressOps[saveOp.ID()] = saveOp
+                }
+            }
+            loadOpsDone = true
+        }
     }
 
     if len(completedOps) != len(dag.Operators()) {
diff --git a/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py b/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
index 421154625..a29b412fb 100644
--- a/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
+++ b/src/llm/aqueduct_llm/utils/dolly_instruct_pipeline.py
@@ -5,7 +5,6 @@
 
 import numpy as np
 from transformers import Pipeline, PreTrainedTokenizer
-
 from transformers.utils import is_tf_available
 
 if is_tf_available():
diff --git a/src/llm/aqueduct_llm/vicuna_7b.py b/src/llm/aqueduct_llm/vicuna_7b.py
index 1645dc898..c5a3e993e 100644
--- a/src/llm/aqueduct_llm/vicuna_7b.py
+++ b/src/llm/aqueduct_llm/vicuna_7b.py
@@ -3,7 +3,7 @@
 
 import torch
 from fastchat.conversation import get_default_conv_template
-from fastchat.serve.inference import load_model, compute_skip_echo_len
+from fastchat.serve.inference import compute_skip_echo_len, load_model
 
 default_max_gpu_memory = "13GiB"
 default_temperature = 0.7
diff --git a/src/llm/setup.py b/src/llm/setup.py
index 12db42e72..4bb3cf452 100644
--- a/src/llm/setup.py
+++ b/src/llm/setup.py
@@ -1,4 +1,4 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = open("requirements.txt").read().strip().split("\n")
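
The aq_engine.go change above amounts to two-phase DAG scheduling: compute operators run as their dependency counts drain to zero, while load (save) operators are held back and scheduled in a single batch only after every compute operator has finished. Below is a minimal, self-contained Go sketch of that idea; the names (operator, opType, schedule, downstream) are simplified stand-ins for illustration, not the actual Aqueduct types.

package main

import "fmt"

// opType distinguishes compute operators from load (save) operators.
type opType int

const (
    computeType opType = iota
    loadType
)

// operator is a simplified stand-in for an Aqueduct DAG operator.
type operator struct {
    name string
    typ  opType
    deps int // number of unfinished upstream operators
}

// schedule runs the DAG in two phases: compute operators first, then the
// deferred load operators once all compute work has drained.
func schedule(ops []*operator, downstream map[string][]*operator) {
    inProgress := map[string]*operator{}

    // Seed with ready compute operators only; load operators are deferred
    // even if they have no dependencies.
    for _, op := range ops {
        if op.typ != loadType && op.deps == 0 {
            inProgress[op.name] = op
        }
    }

    loadOpsDone := false
    for len(inProgress) > 0 {
        for name, op := range inProgress {
            fmt.Println("running", op.name) // stand-in for real execution/polling
            delete(inProgress, name)

            for _, next := range downstream[op.name] {
                next.deps--
                // Schedule newly ready operators, but keep deferring loads.
                if next.deps == 0 && next.typ != loadType {
                    inProgress[next.name] = next
                }
            }
        }

        // All compute work is drained: schedule the deferred loads exactly once.
        // A real engine would have returned on any compute failure before this
        // point, so loads only run when every compute operator succeeded.
        if len(inProgress) == 0 && !loadOpsDone {
            for _, op := range ops {
                if op.typ == loadType {
                    inProgress[op.name] = op
                }
            }
            loadOpsDone = true
        }
    }
}

func main() {
    extract := &operator{name: "extract", typ: computeType}
    transform := &operator{name: "transform", typ: computeType, deps: 1}
    save := &operator{name: "save", typ: loadType, deps: 1}

    // save persists extract's output, mirroring the integration test above.
    schedule(
        []*operator{extract, transform, save},
        map[string][]*operator{"extract": {transform, save}},
    )
}

As in the patch, the single loadOpsDone flag makes the deferred scheduling fire exactly once even though the in-progress set drains and refills, and failing fast on any compute error before the refill is what guarantees nothing is persisted for a failed workflow.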