Skip to content

Fixes Airflow to Aqueduct syncing bug #1347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/golang/lib/airflow/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/apache/airflow-client-go/airflow"
"github.com/aqueducthq/aqueduct/lib/errors"
"github.com/aqueducthq/aqueduct/lib/workflow/operator/connector/auth"
)

Expand Down Expand Up @@ -98,11 +99,33 @@ func (c *client) getTaskStates(dagId string, dagRunId string) (map[string]airflo
return taskIdToState, nil
}

// isDAGPaused returns whether or not the specified DAG is paused.
func (c *client) isDAGPaused(dagID string) (bool, error) {
dag, err := c.getDag(dagID)
if err != nil {
return false, err
}

return dag.GetIsPaused(), nil
}

// trigerDAGRun triggers a new DAGRun for the dag specified.
func (c *client) triggerDAGRun(dagId string) error {
request := c.apiClient.DAGRunApi.PostDagRun(c.ctx, dagId)
// It first ensures that the DAG is not paused.
func (c *client) triggerDAGRun(dagID string) error {
// Check if DAG is paused
paused, err := c.isDAGPaused(dagID)
if err != nil {
return err
}

if paused {
// TODO ENG-3002: Automatically unpause the DAG instead of throwing an error
return errors.Newf("Unable to trigger a new DAG run for %v because it is currently paused. You must unpause it first!", dagID)
}

request := c.apiClient.DAGRunApi.PostDagRun(c.ctx, dagID)
// The PostDagRun API requires the request to have a DAGRun initialized
request = request.DAGRun(*airflow.NewDAGRunWithDefaults())
_, _, err := request.Execute()
_, _, err = request.Execute()
return err
}
30 changes: 15 additions & 15 deletions src/golang/lib/airflow/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ func syncWorkflowDag(
vault vault.Vault,
DB database.Database,
) error {
txn, err := DB.BeginTx(ctx)
if err != nil {
return err
}
defer database.TxnRollbackIgnoreErr(ctx, txn)

// Read Airflow credentials from vault
authConf, err := auth.ReadConfigFromSecret(
ctx,
Expand All @@ -104,7 +110,7 @@ func syncWorkflowDag(
cli,
dag,
dagRepo,
DB,
txn,
)
if err != nil {
return err
Expand All @@ -123,7 +129,7 @@ func syncWorkflowDag(

// Default values to not have an order and not have a limit: Empty string for order_by, -1 for limit
// Set true for order_by order (desc/asc) because doesn't matter.
dagResults, err := dagResultRepo.GetByWorkflow(ctx, dag.WorkflowID, "", -1, true, DB)
dagResults, err := dagResultRepo.GetByWorkflow(ctx, dag.WorkflowID, "", -1, true, txn)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,12 +166,16 @@ func syncWorkflowDag(
dagResultRepo,
operatorResultRepo,
artifactResultRepo,
DB,
txn,
); err != nil {
return err
}
}

if err := txn.Commit(ctx); err != nil {
return err
}

return nil
}

Expand All @@ -182,18 +192,12 @@ func syncWorkflowDagResult(
artifactResultRepo repos.ArtifactResult,
DB database.Database,
) error {
txn, err := DB.BeginTx(ctx)
if err != nil {
return err
}
defer database.TxnRollbackIgnoreErr(ctx, txn)

dagResult, err := createDAGResult(
ctx,
dag,
run,
dagResultRepo,
txn,
DB,
)
if err != nil {
return err
Expand Down Expand Up @@ -228,16 +232,12 @@ func syncWorkflowDagResult(
dagResult.ID,
operatorResultRepo,
artifactResultRepo,
txn,
DB,
); err != nil {
return err
}
}

if err := txn.Commit(ctx); err != nil {
return err
}

return nil
}

Expand Down
5 changes: 3 additions & 2 deletions src/python/aqueduct_executor/operators/airflow/dag.template
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def invoke_task(spec, **kwargs):
elif spec_type == enums.JobType.LOAD:
spec = conn_spec.LoadSpec(**spec)
spec.metadata_path = "{}_{}".format(spec.metadata_path, dag_run_id)
spec.input_content_path = "{}_{}".format(spec.input_content_path, dag_run_id)
spec.input_metadata_path = "{}_{}".format(spec.input_metadata_path, dag_run_id)
spec.input_content_paths = ["{}_{}".format(p, dag_run_id) for p in spec.input_content_paths]
spec.input_metadata_paths = ["{}_{}".format(p, dag_run_id) for p in spec.input_metadata_paths]
conn_execute.run(spec)
elif spec_type == enums.JobType.PARAM:
spec = param_spec.ParamSpec(**spec)
Expand Down Expand Up @@ -111,6 +111,7 @@ with DAG(
schedule_interval={{ schedule }},
{% endif %}
catchup=False,
is_paused_upon_creation=False,
tags=['aqueduct', '{{ workflow_dag_id }}'],
) as dag:
# Constants to handle JSON serialization
Expand Down