diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 9c06d9eee..7ef4b607c 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -67,7 +67,7 @@ jobs: - name: Fetch the API key run: echo "API_KEY=$(aqueduct apikey)" >> $GITHUB_ENV - - name: Run the Integration Tests + - name: Run the SDK Integration Tests timeout-minutes: 10 working-directory: integration_tests/sdk env: @@ -75,6 +75,13 @@ jobs: INTEGRATION: aqueduct_demo run: pytest . -rP --publish -n 1 + - name: Run the Backend Integration Tests + timeout-minutes: 10 + working-directory: integration_tests/backend + env: + SERVER_ADDRESS: localhost:8080 + run: pytest . -rP -n 1 + - uses: actions/upload-artifact@v3 if: always() with: diff --git a/integration_tests/backend/README.md b/integration_tests/backend/README.md new file mode 100644 index 000000000..5512fc7cf --- /dev/null +++ b/integration_tests/backend/README.md @@ -0,0 +1,29 @@ +# Backend Integration Tests + +These tests are run against the Aqueduct backend to check that the endpoints' reads and outputs are as expected. + +The `setup_class` sets up all the workflows which are read by each test. When all the tests in the suite are done, the workflows set up in `setup_class` are deleted in the `teardown_class`. + +## Creating Tests +The workflows run during test setup are all the Python files in the `setup/` folder. Each Python file is called in the format `{python_file} {api_key} {server_address}`. At the top, you can parse those arguments like so: ``` +import sys +api_key, server_address = sys.argv[1], sys.argv[2] +``` +After that, you can write a workflow as a typical user would. +At the very end, the tests **require** you to print the flow id (e.g. `print(flow.id())`). This is parsed by the suite setup function and saved to a list of flow ids. At the end of testing, the teardown function will iterate through the flow ids and delete the associated workflows. 
+ +## Usage + +Running all the tests in this repo: +`API_KEY=<api_key> SERVER_ADDRESS=<server_address> pytest . -rP` + +Running all the tests in a single file: +- `<env_vars> pytest <test_file> -rP` + +Running a specific test: +- `<env_vars> pytest . -rP -k '<test_name>'` + +Running tests in parallel, with concurrency 5: +- Install pytest-xdist +- `<env_vars> pytest . -rP -n 5` diff --git a/integration_tests/backend/conftest.py b/integration_tests/backend/conftest.py new file mode 100644 index 000000000..a4e554261 --- /dev/null +++ b/integration_tests/backend/conftest.py @@ -0,0 +1,18 @@ +import os + +import pytest + +import aqueduct + +API_KEY_ENV_NAME = "API_KEY" +SERVER_ADDR_ENV_NAME = "SERVER_ADDRESS" + + +def pytest_configure(config): + pytest.api_key = os.getenv(API_KEY_ENV_NAME) + pytest.server_address = os.getenv(SERVER_ADDR_ENV_NAME) + + if pytest.api_key is None or pytest.server_address is None: + raise Exception( + "Test Setup Error: api_key and server_address must be set as environmental variables." + ) diff --git a/integration_tests/backend/setup/changing_saves.py b/integration_tests/backend/setup/changing_saves.py new file mode 100644 index 000000000..638851dee --- /dev/null +++ b/integration_tests/backend/setup/changing_saves.py @@ -0,0 +1,64 @@ +import sys + +api_key, server_address = sys.argv[1], sys.argv[2] + +### +# Workflow that loads a table from the `aqueduct_demo` then saves it to `table_1` in append mode. +# This save operator is then replaced by one that saves to `table_1` in replace mode. +# In the next deployment of this run, it saves to `table_1` in append mode. +# In the last deployment, it saves to `table_2` in append mode. 
+### + +import aqueduct + +name = "Test: Changing Saves" +client = aqueduct.Client(api_key, server_address) +integration = client.integration(name="aqueduct_demo") + +### + +table = integration.sql(query="SELECT * FROM wine;") + +table.save(integration.config(table="table_1", update_mode="append")) + +flow = client.publish_flow( + name=name, + artifacts=[table], +) + +### + +table = integration.sql(query="SELECT * FROM wine;") + +table.save(integration.config(table="table_1", update_mode="replace")) + +flow = client.publish_flow( + name=name, + artifacts=[table], +) + +### + +table = integration.sql(query="SELECT * FROM wine;") + +table.save(integration.config(table="table_1", update_mode="append")) + +flow = client.publish_flow( + name=name, + artifacts=[table], +) + +### + +table = integration.sql(query="SELECT * FROM wine;") + +table.save(integration.config(table="table_2", update_mode="append")) + +flow = client.publish_flow( + name=name, + artifacts=[table], +) + +### + +print(flow.id()) diff --git a/integration_tests/backend/test_backend.py b/integration_tests/backend/test_backend.py new file mode 100644 index 000000000..3d73431e0 --- /dev/null +++ b/integration_tests/backend/test_backend.py @@ -0,0 +1,76 @@ +import os +import subprocess +import sys +from pathlib import Path + +import pytest +import requests + +import aqueduct + + +class TestBackend: + GET_WORKFLOW_TABLES_TEMPLATE = "/api/workflow/%s/tables" + WORKFLOW_PATH = Path(__file__).parent / "setup" + + @classmethod + def setup_class(cls): + cls.client = aqueduct.Client(pytest.api_key, pytest.server_address) + cls.flows = {} + + workflow_files = [ + f + for f in os.listdir(cls.WORKFLOW_PATH) + if os.path.isfile(os.path.join(cls.WORKFLOW_PATH, f)) + ] + for workflow in workflow_files: + proc = subprocess.Popen( + [ + "python3", + os.path.join(cls.WORKFLOW_PATH, workflow), + pytest.api_key, + pytest.server_address, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, err = 
proc.communicate() + out = out.decode("utf-8") + err = err.decode("utf-8") + if err: + raise Exception(f"Could not run workflow {workflow}.\n\n{err}") + else: + cls.flows[workflow] = out.strip().split()[-1] + + @classmethod + def teardown_class(cls): + for flow in cls.flows: + cls.client.delete_flow(cls.flows[flow]) + + @classmethod + def get_response_class(cls, endpoint, additional_headers={}): + headers = {"api-key": pytest.api_key} + headers.update(additional_headers) + url = cls.client._api_client.construct_full_url(endpoint) + r = requests.get(url, headers=headers) + return r + + def test_endpoint_getworkflowtables(self): + endpoint = self.GET_WORKFLOW_TABLES_TEMPLATE % self.flows["changing_saves.py"] + data = self.get_response_class(endpoint).json()["table_details"] + + assert len(data) == 3 + + # table_name, update_mode + data_set = set( + [ + ("table_1", "append"), + ("table_1", "replace"), + ("table_2", "append"), + ] + ) + assert set([(item["table_name"], item["update_mode"]) for item in data]) == data_set + + # Check all in same integration + assert len(set([item["integration_id"] for item in data])) == 1 + assert len(set([item["service"] for item in data])) == 1 diff --git a/scripts/install_local.py b/scripts/install_local.py index 627bbe975..b07583dfc 100644 --- a/scripts/install_local.py +++ b/scripts/install_local.py @@ -129,9 +129,13 @@ def execute_command(args, cwd=None): # directory /home/ec2-user/SageMaker exists. This is hacky but we couldn't find a better # solution at the moment. 
if isdir(join(os.sep, "home", "ec2-user", "SageMaker")): - shutil.copytree(join(cwd, "src", "ui", "app" ,"dist", "sagemaker"), ui_directory, dirs_exist_ok=True) + shutil.copytree( + join(cwd, "src", "ui", "app", "dist", "sagemaker"), ui_directory, dirs_exist_ok=True + ) else: - shutil.copytree(join(cwd, "src", "ui", "app" ,"dist", "default"), ui_directory, dirs_exist_ok=True) + shutil.copytree( + join(cwd, "src", "ui", "app", "dist", "default"), ui_directory, dirs_exist_ok=True + ) # Install the local SDK. if args.update_sdk: diff --git a/sdk/aqueduct/api_client.py b/sdk/aqueduct/api_client.py index 2efeeb909..06f0c72e3 100644 --- a/sdk/aqueduct/api_client.py +++ b/sdk/aqueduct/api_client.py @@ -81,6 +81,10 @@ def __init__(self, api_key: str, aqueduct_address: str): self.api_key = api_key self.aqueduct_address = aqueduct_address + # Clean URL + if self.aqueduct_address.endswith("/"): + self.aqueduct_address = self.aqueduct_address[:-1] + # If a dummy client is initialized, don't perform validation. if self.api_key == "" and self.aqueduct_address == "": Logger.logger.info( @@ -98,9 +102,16 @@ def __init__(self, api_key: str, aqueduct_address: str): else: self.use_https = self._test_connection_protocol(try_http=True, try_https=True) - def _construct_full_url(self, route_suffix: str, use_https: bool) -> str: + def construct_base_url(self, use_https: Optional[bool] = None) -> str: + if use_https is None: + use_https = self.use_https protocol_prefix = self.HTTPS_PREFIX if use_https else self.HTTP_PREFIX - return "%s%s%s" % (protocol_prefix, self.aqueduct_address, route_suffix) + return "%s%s" % (protocol_prefix, self.aqueduct_address) + + def construct_full_url(self, route_suffix: str, use_https: Optional[bool] = None) -> str: + if use_https is None: + use_https = self.use_https + return "%s%s" % (self.construct_base_url(use_https), route_suffix) def _test_connection_protocol(self, try_http: bool, try_https: bool) -> bool: """Returns whether the connection uses https. 
Raises an exception if unable to connect at all. @@ -111,7 +122,7 @@ def _test_connection_protocol(self, try_http: bool, try_https: bool) -> bool: if try_https: try: - url = self._construct_full_url(self.LIST_INTEGRATIONS_ROUTE, use_https=True) + url = self.construct_full_url(self.LIST_INTEGRATIONS_ROUTE, use_https=True) self._test_url(url) return True except Exception as e: @@ -121,7 +132,7 @@ def _test_connection_protocol(self, try_http: bool, try_https: bool) -> bool: if try_http: try: - url = self._construct_full_url(self.LIST_INTEGRATIONS_ROUTE, use_https=False) + url = self.construct_full_url(self.LIST_INTEGRATIONS_ROUTE, use_https=False) self._test_url(url) return False except Exception as e: @@ -147,7 +158,7 @@ def url_prefix(self) -> str: return self.HTTPS_PREFIX if self.use_https else self.HTTP_PREFIX def list_integrations(self) -> Dict[str, IntegrationInfo]: - url = self._construct_full_url(self.LIST_INTEGRATIONS_ROUTE, self.use_https) + url = self.construct_full_url(self.LIST_INTEGRATIONS_ROUTE) headers = utils.generate_auth_headers(self.api_key) resp = requests.get(url, headers=headers) utils.raise_errors(resp) @@ -162,14 +173,14 @@ def list_integrations(self) -> Dict[str, IntegrationInfo]: } def list_github_repos(self) -> List[str]: - url = self._construct_full_url(self.LIST_GITHUB_REPO_ROUTE, self.use_https) + url = self.construct_full_url(self.LIST_GITHUB_REPO_ROUTE) headers = utils.generate_auth_headers(self.api_key) resp = requests.get(url, headers=headers) return [x for x in resp.json()["repos"]] def list_github_branches(self, repo_url: str) -> List[str]: - url = self._construct_full_url(self.LIST_GITHUB_BRANCH_ROUTE, self.use_https) + url = self.construct_full_url(self.LIST_GITHUB_BRANCH_ROUTE) headers = utils.generate_auth_headers(self.api_key) headers["github-repo"] = repo_url @@ -177,7 +188,7 @@ def list_github_branches(self, repo_url: str) -> List[str]: return [x for x in resp.json()["branches"]] def list_tables(self, limit: int) -> 
List[Tuple[str, str]]: - url = self._construct_full_url(self.LIST_TABLES_ROUTE, self.use_https) + url = self.construct_full_url(self.LIST_TABLES_ROUTE) headers = utils.generate_auth_headers(self.api_key) headers["limit"] = str(limit) resp = requests.get(url, headers=headers) @@ -212,7 +223,7 @@ def preview( if file: files[str(op.id)] = io.BytesIO(file) - url = self._construct_full_url(self.PREVIEW_ROUTE, self.use_https) + url = self.construct_full_url(self.PREVIEW_ROUTE) resp = requests.post(url, headers=headers, data=body, files=files) utils.raise_errors(resp) @@ -241,7 +252,7 @@ def register_workflow( if file: files[str(op.id)] = io.BytesIO(file) - url = self._construct_full_url(self.REGISTER_WORKFLOW_ROUTE, self.use_https) + url = self.construct_full_url(self.REGISTER_WORKFLOW_ROUTE) resp = requests.post(url, headers=headers, data=body, files=files) utils.raise_errors(resp) @@ -253,9 +264,7 @@ def refresh_workflow( serialized_params: Optional[str] = None, ) -> None: headers = utils.generate_auth_headers(self.api_key) - url = self._construct_full_url( - self.REFRESH_WORKFLOW_ROUTE_TEMPLATE % flow_id, self.use_https - ) + url = self.construct_full_url(self.REFRESH_WORKFLOW_ROUTE_TEMPLATE % flow_id) body = {} if serialized_params is not None: @@ -266,15 +275,13 @@ def refresh_workflow( def delete_workflow(self, flow_id: str) -> None: headers = utils.generate_auth_headers(self.api_key) - url = self._construct_full_url( - self.DELETE_WORKFLOW_ROUTE_TEMPLATE % flow_id, self.use_https - ) + url = self.construct_full_url(self.DELETE_WORKFLOW_ROUTE_TEMPLATE % flow_id) response = requests.post(url, headers=headers) utils.raise_errors(response) def get_workflow(self, flow_id: str) -> GetWorkflowResponse: headers = utils.generate_auth_headers(self.api_key) - url = self._construct_full_url(self.GET_WORKFLOW_ROUTE_TEMPLATE % flow_id, self.use_https) + url = self.construct_full_url(self.GET_WORKFLOW_ROUTE_TEMPLATE % flow_id) resp = requests.get(url, headers=headers) 
utils.raise_errors(resp) workflow_response = GetWorkflowResponse(**resp.json()) @@ -282,7 +289,7 @@ def get_workflow(self, flow_id: str) -> GetWorkflowResponse: def list_workflows(self) -> List[ListWorkflowResponseEntry]: headers = utils.generate_auth_headers(self.api_key) - url = self._construct_full_url(self.LIST_WORKFLOWS_ROUTE, self.use_https) + url = self.construct_full_url(self.LIST_WORKFLOWS_ROUTE) response = requests.get(url, headers=headers) utils.raise_errors(response) @@ -291,8 +298,8 @@ def list_workflows(self) -> List[ListWorkflowResponseEntry]: def get_artifact_result_data(self, dag_result_id: str, artifact_id: str) -> str: """Returns an empty string if the artifact failed to be computed.""" headers = utils.generate_auth_headers(self.api_key) - url = self._construct_full_url( - self.GET_ARTIFACT_RESULT_TEMPLATE % (dag_result_id, artifact_id), self.use_https + url = self.construct_full_url( + self.GET_ARTIFACT_RESULT_TEMPLATE % (dag_result_id, artifact_id) ) resp = requests.get(url, headers=headers) utils.raise_errors(resp) @@ -318,7 +325,7 @@ def get_node_positions( Two mappings of UUIDs to positions, structured as a dictionary with the keys "x" and "y". The first mapping is for operators and the second is for artifacts. 
""" - url = self._construct_full_url(self.NODE_POSITION_ROUTE, self.use_https) + url = self.construct_full_url(self.NODE_POSITION_ROUTE) headers = { **utils.generate_auth_headers(self.api_key), } @@ -332,8 +339,6 @@ def get_node_positions( def export_serialized_function(self, operator: Operator) -> bytes: headers = utils.generate_auth_headers(self.api_key) - operator_url = self._construct_full_url( - self.EXPORT_FUNCTION_ROUTE % str(operator.id), self.use_https - ) + operator_url = self.construct_full_url(self.EXPORT_FUNCTION_ROUTE % str(operator.id)) operator_resp = requests.get(operator_url, headers=headers) return operator_resp.content diff --git a/sdk/aqueduct/aqueduct_client.py b/sdk/aqueduct/aqueduct_client.py index 55b436923..e322081ea 100644 --- a/sdk/aqueduct/aqueduct_client.py +++ b/sdk/aqueduct/aqueduct_client.py @@ -337,8 +337,7 @@ def publish_flow( flow_id = self._api_client.register_workflow(dag).id url = generate_ui_url( - self._api_client.url_prefix(), - self._api_client.aqueduct_address, + self._api_client.construct_base_url(), str(flow_id), ) print("Url: ", url) diff --git a/sdk/aqueduct/flow.py b/sdk/aqueduct/flow.py index 25da2e2ce..3e7fa0aec 100644 --- a/sdk/aqueduct/flow.py +++ b/sdk/aqueduct/flow.py @@ -12,7 +12,7 @@ from .logger import Logger from .operators import OperatorSpec, ParamSpec from .responses import WorkflowDagResponse, WorkflowDagResultResponse -from .utils import generate_ui_url, parse_user_supplied_id, format_header_for_print +from .utils import format_header_for_print, generate_ui_url, parse_user_supplied_id class Flow: @@ -154,9 +154,7 @@ def describe(self) -> None: assert latest_metadata.schedule is not None, "A flow must have a schedule." assert latest_metadata.retention_policy is not None, "A flow must have a retention policy." 
- url = generate_ui_url( - self._api_client.url_prefix(), self._api_client.aqueduct_address, self._id - ) + url = generate_ui_url(self._api_client.construct_base_url(), self._id) print( textwrap.dedent( diff --git a/sdk/aqueduct/flow_run.py b/sdk/aqueduct/flow_run.py index c216e3996..d41604928 100644 --- a/sdk/aqueduct/flow_run.py +++ b/sdk/aqueduct/flow_run.py @@ -1,20 +1,20 @@ import textwrap import uuid from textwrap import wrap -from typing import Dict, Any, Mapping, Optional, Union, List +from typing import Any, Dict, List, Mapping, Optional, Union import plotly.graph_objects as go from aqueduct.api_client import APIClient from aqueduct.artifact import Artifact, get_artifact_type from aqueduct.check_artifact import CheckArtifact from aqueduct.dag import DAG -from aqueduct.enums import OperatorType, DisplayNodeType, ExecutionStatus, ArtifactType +from aqueduct.enums import ArtifactType, DisplayNodeType, ExecutionStatus, OperatorType from aqueduct.error import InternalAqueductError from aqueduct.metric_artifact import MetricArtifact from aqueduct.operators import Operator from aqueduct.param_artifact import ParamArtifact from aqueduct.table_artifact import TableArtifact -from aqueduct.utils import generate_ui_url, human_readable_timestamp, format_header_for_print +from aqueduct.utils import format_header_for_print, generate_ui_url, human_readable_timestamp class FlowRun: @@ -51,8 +51,7 @@ def describe(self) -> None: """Prints out a human-readable description of the flow run.""" url = generate_ui_url( - self._api_client.url_prefix(), - self._api_client.aqueduct_address, + self._api_client.construct_base_url(), self._flow_id, self._id, ) diff --git a/sdk/aqueduct/utils.py b/sdk/aqueduct/utils.py index 01adb0643..44d8911f1 100644 --- a/sdk/aqueduct/utils.py +++ b/sdk/aqueduct/utils.py @@ -57,19 +57,17 @@ def generate_uuid() -> uuid.UUID: def generate_ui_url( - url_prefix: str, aqueduct_address: str, workflow_id: str, result_id: Optional[str] = None + 
aqueduct_base_address: str, workflow_id: str, result_id: Optional[str] = None ) -> str: if result_id: - url = "%s%s%s%s" % ( - url_prefix, - aqueduct_address, + url = "%s%s%s" % ( + aqueduct_base_address, WORKFLOW_UI_ROUTE_TEMPLATE % workflow_id, WORKFLOW_RUN_UI_ROUTE_TEMPLATE % result_id, ) else: - url = "%s%s%s" % ( - url_prefix, - aqueduct_address, + url = "%s%s" % ( + aqueduct_base_address, WORKFLOW_UI_ROUTE_TEMPLATE % workflow_id, ) return url diff --git a/src/golang/cmd/server/handler/get_workflow_tables.go b/src/golang/cmd/server/handler/get_workflow_tables.go new file mode 100644 index 000000000..1036a0c0e --- /dev/null +++ b/src/golang/cmd/server/handler/get_workflow_tables.go @@ -0,0 +1,94 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/aqueducthq/aqueduct/cmd/server/routes" + "github.com/aqueducthq/aqueduct/lib/collections/operator" + "github.com/aqueducthq/aqueduct/lib/collections/workflow" + aq_context "github.com/aqueducthq/aqueduct/lib/context" + "github.com/aqueducthq/aqueduct/lib/database" + "github.com/dropbox/godropbox/errors" + "github.com/go-chi/chi" + "github.com/google/uuid" +) + +// Route: /workflow/{workflowId}/tables +// Method: GET +// Params: +// `workflowId`: ID for `workflow` object +// Request: +// Headers: +// `api-key`: user's API Key +// Response: +// Body: +// all tables for the given `workflowId` + +type getWorkflowTablesArgs struct { + *aq_context.AqContext + workflowId uuid.UUID +} + +type getWorkflowTablesResponse struct { + LoadDetails []operator.GetDistinctLoadOperatorsByWorkflowIdResponse `json:"table_details"` +} + +type GetWorkflowTablesHandler struct { + GetHandler + + Database database.Database + OperatorReader operator.Reader + WorkflowReader workflow.Reader +} + +func (*GetWorkflowTablesHandler) Name() string { + return "GetWorkflowTables" +} + +func (h *GetWorkflowTablesHandler) Prepare(r *http.Request) (interface{}, int, error) { + aqContext, statusCode, err := 
aq_context.ParseAqContext(r.Context()) + if err != nil { + return nil, statusCode, err + } + + workflowIdStr := chi.URLParam(r, routes.WorkflowIdUrlParam) + workflowId, err := uuid.Parse(workflowIdStr) + if err != nil { + return nil, http.StatusBadRequest, errors.Wrap(err, "Malformed workflow ID.") + } + + ok, err := h.WorkflowReader.ValidateWorkflowOwnership( + r.Context(), + workflowId, + aqContext.OrganizationId, + h.Database, + ) + if err != nil { + return nil, http.StatusInternalServerError, errors.Wrap(err, "Unexpected error during workflow ownership validation.") + } + if !ok { + return nil, http.StatusBadRequest, errors.Wrap(err, "The organization does not own this workflow.") + } + + return &getWorkflowTablesArgs{ + AqContext: aqContext, + workflowId: workflowId, + }, http.StatusOK, nil +} + +func (h *GetWorkflowTablesHandler) Perform(ctx context.Context, interfaceArgs interface{}) (interface{}, int, error) { + args := interfaceArgs.(*getWorkflowTablesArgs) + + emptyResp := getWorkflowTablesResponse{} + + // Get all specs for the workflow. 
+ operatorList, err := h.OperatorReader.GetDistinctLoadOperatorsByWorkflowId(ctx, args.workflowId, h.Database) + if err != nil { + return emptyResp, http.StatusInternalServerError, errors.Wrap(err, "Unexpected error occurred when retrieving workflow.") + } + + return getWorkflowTablesResponse{ + LoadDetails: operatorList, + }, http.StatusOK, nil +} diff --git a/src/golang/cmd/server/routes/routes.go b/src/golang/cmd/server/routes/routes.go index a424adc96..7c15b933c 100644 --- a/src/golang/cmd/server/routes/routes.go +++ b/src/golang/cmd/server/routes/routes.go @@ -28,12 +28,13 @@ const ( GetUserProfileRoute = "/api/user" - ListWorkflowsRoute = "/api/workflows" - RegisterWorkflowRoute = "/api/workflow/register" - GetWorkflowRoute = "/api/workflow/{workflowId}" - DeleteWorkflowRoute = "/api/workflow/{workflowId}/delete" - EditWorkflowRoute = "/api/workflow/{workflowId}/edit" - RefreshWorkflowRoute = "/api/workflow/{workflowId}/refresh" - UnwatchWorkflowRoute = "/api/workflow/{workflowId}/unwatch" - WatchWorkflowRoute = "/api/workflow/{workflowId}/watch" + ListWorkflowsRoute = "/api/workflows" + RegisterWorkflowRoute = "/api/workflow/register" + GetWorkflowRoute = "/api/workflow/{workflowId}" + GetWorkflowTablesRoute = "/api/workflow/{workflowId}/tables" + DeleteWorkflowRoute = "/api/workflow/{workflowId}/delete" + EditWorkflowRoute = "/api/workflow/{workflowId}/edit" + RefreshWorkflowRoute = "/api/workflow/{workflowId}/refresh" + UnwatchWorkflowRoute = "/api/workflow/{workflowId}/unwatch" + WatchWorkflowRoute = "/api/workflow/{workflowId}/watch" ) diff --git a/src/golang/cmd/server/server/handlers.go b/src/golang/cmd/server/server/handlers.go index 65aa3b011..46d3d5f7a 100644 --- a/src/golang/cmd/server/server/handlers.go +++ b/src/golang/cmd/server/server/handlers.go @@ -70,6 +70,11 @@ func (s *AqServer) Handlers() map[string]handler.Handler { OperatorResultReader: s.OperatorResultReader, }, routes.GetUserProfileRoute: &handler.GetUserProfileHandler{}, + 
routes.GetWorkflowTablesRoute: &handler.GetWorkflowTablesHandler{ + Database: s.Database, + OperatorReader: s.OperatorReader, + WorkflowReader: s.WorkflowReader, + }, routes.GetWorkflowRoute: &handler.GetWorkflowHandler{ Database: s.Database, ArtifactReader: s.ArtifactReader, diff --git a/src/golang/internal/server/routes/routes.go b/src/golang/internal/server/routes/routes.go index e38b06891..2fc246c38 100644 --- a/src/golang/internal/server/routes/routes.go +++ b/src/golang/internal/server/routes/routes.go @@ -28,12 +28,13 @@ const ( GetUserProfileRoute = "/user" - ListWorkflowsRoute = "/workflows" - RegisterWorkflowRoute = "/workflow/register" - GetWorkflowRoute = "/workflow/{workflowId}" - DeleteWorkflowRoute = "/workflow/{workflowId}/delete" - EditWorkflowRoute = "/workflow/{workflowId}/edit" - RefreshWorkflowRoute = "/workflow/{workflowId}/refresh" - UnwatchWorkflowRoute = "/workflow/{workflowId}/unwatch" - WatchWorkflowRoute = "/workflow/{workflowId}/watch" + ListWorkflowsRoute = "/workflows" + RegisterWorkflowRoute = "/workflow/register" + GetWorkflowRoute = "/workflow/{workflowId}" + GetWorkflowTablesRoute = "/workflow/{workflowId}/tables" + DeleteWorkflowRoute = "/workflow/{workflowId}/delete" + EditWorkflowRoute = "/workflow/{workflowId}/edit" + RefreshWorkflowRoute = "/workflow/{workflowId}/refresh" + UnwatchWorkflowRoute = "/workflow/{workflowId}/unwatch" + WatchWorkflowRoute = "/workflow/{workflowId}/watch" ) diff --git a/src/golang/lib/collections/operator/operator.go b/src/golang/lib/collections/operator/operator.go index 46a9f7bf2..0018677ef 100644 --- a/src/golang/lib/collections/operator/operator.go +++ b/src/golang/lib/collections/operator/operator.go @@ -27,6 +27,11 @@ type Reader interface { workflowDagId uuid.UUID, db database.Database, ) ([]DBOperator, error) + GetDistinctLoadOperatorsByWorkflowId( + ctx context.Context, + workflowId uuid.UUID, + db database.Database, + ) ([]GetDistinctLoadOperatorsByWorkflowIdResponse, error) 
ValidateOperatorOwnership( ctx context.Context, organizationId string, diff --git a/src/golang/lib/collections/operator/postgres.go b/src/golang/lib/collections/operator/postgres.go index 7254e6b66..d006da15a 100644 --- a/src/golang/lib/collections/operator/postgres.go +++ b/src/golang/lib/collections/operator/postgres.go @@ -1,5 +1,12 @@ package operator +import ( + "context" + + "github.com/aqueducthq/aqueduct/lib/database" + "github.com/google/uuid" +) + type postgresReaderImpl struct { standardReaderImpl } @@ -15,3 +22,36 @@ func newPostgresReader() Reader { func newPostgresWriter() Writer { return &postgresWriterImpl{standardWriterImpl{}} } + +func (r *postgresReaderImpl) GetDistinctLoadOperatorsByWorkflowId( + ctx context.Context, + workflowId uuid.UUID, + db database.Database, +) ([]GetDistinctLoadOperatorsByWorkflowIdResponse, error) { + query := ` + SELECT DISTINCT + name, + json_extract_path_text(spec, 'load', 'integration_id') AS integration_id, + json_extract_path_text(spec, 'load', 'service') AS service, + json_extract_path_text(spec, 'load', 'parameters', 'table') AS table_name, + json_extract_path_text(spec, 'load', 'parameters', 'update_mode') AS update_mode + FROM operator WHERE ( + json_extract_path_text(spec, 'type') = 'load' AND + EXISTS ( + SELECT 1 + FROM + workflow_dag_edge, workflow_dag + WHERE + ( + workflow_dag_edge.from_id = operator.id OR + workflow_dag_edge.to_id = operator.id + ) AND + workflow_dag_edge.workflow_dag_id = workflow_dag.id AND + workflow_dag.workflow_id = $1 + ) + );` + + var workflowSpecs []GetDistinctLoadOperatorsByWorkflowIdResponse + err := db.Query(ctx, &workflowSpecs, query, workflowId) + return workflowSpecs, err +} diff --git a/src/golang/lib/collections/operator/sqlite.go b/src/golang/lib/collections/operator/sqlite.go index c9743908e..d54cdf9e8 100644 --- a/src/golang/lib/collections/operator/sqlite.go +++ b/src/golang/lib/collections/operator/sqlite.go @@ -5,6 +5,7 @@ import ( 
"github.com/aqueducthq/aqueduct/lib/collections/utils" "github.com/aqueducthq/aqueduct/lib/database" + "github.com/google/uuid" ) type sqliteReaderImpl struct { @@ -44,3 +45,38 @@ func (w *sqliteWriterImpl) CreateOperator( err = db.Query(ctx, &operator, insertOperatorStmt, args...) return &operator, err } + +func (r *sqliteReaderImpl) GetDistinctLoadOperatorsByWorkflowId( + ctx context.Context, + workflowId uuid.UUID, + db database.Database, +) ([]GetDistinctLoadOperatorsByWorkflowIdResponse, error) { + query := ` + SELECT DISTINCT + name, + json_extract(spec, '$.load.integration_id') AS integration_id, + json_extract(spec, '$.load.service') AS service, + json_extract(spec, '$.load.parameters.table') AS table_name, + json_extract(spec, '$.load.parameters.update_mode') AS update_mode + FROM + operator + WHERE ( + json_extract(spec, '$.type')='load' AND + EXISTS ( + SELECT 1 + FROM + workflow_dag_edge, workflow_dag + WHERE + ( + workflow_dag_edge.from_id = operator.id OR + workflow_dag_edge.to_id = operator.id + ) AND + workflow_dag_edge.workflow_dag_id = workflow_dag.id AND + workflow_dag.workflow_id = $1 + ) + );` + + var workflowSpecs []GetDistinctLoadOperatorsByWorkflowIdResponse + err := db.Query(ctx, &workflowSpecs, query, workflowId) + return workflowSpecs, err +} diff --git a/src/golang/lib/collections/operator/standard.go b/src/golang/lib/collections/operator/standard.go index b4d615c4d..15c28cdb4 100644 --- a/src/golang/lib/collections/operator/standard.go +++ b/src/golang/lib/collections/operator/standard.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/aqueducthq/aqueduct/lib/collections/integration" "github.com/aqueducthq/aqueduct/lib/collections/utils" "github.com/aqueducthq/aqueduct/lib/collections/workflow_dag_edge" "github.com/aqueducthq/aqueduct/lib/database" @@ -16,6 +17,14 @@ type standardReaderImpl struct{} type standardWriterImpl struct{} +type GetDistinctLoadOperatorsByWorkflowIdResponse struct { + Name string `db:"name" json:"name"` + 
Integration_id uuid.UUID `db:"integration_id" json:"integration_id"` + Service integration.Service `db:"service" json:"service"` + Table_name string `db:"table_name" json:"table_name"` + Update_mode string `db:"update_mode" json:"update_mode"` +} + func (w *standardWriterImpl) CreateOperator( ctx context.Context, name string,