From 78960c5a6ca60407914801a46963b9d9d0bd7483 Mon Sep 17 00:00:00 2001 From: Kenneth Xu Date: Mon, 13 Jun 2022 12:14:44 -0700 Subject: [PATCH 1/2] ready --- integration_tests/sdk/param_test.py | 23 ++++++++++++++ sdk/aqueduct/aqueduct_client.py | 12 ++------ sdk/aqueduct/check_artifact.py | 11 +++++-- sdk/aqueduct/dag.py | 47 +++++++++++++++++++++++++++-- sdk/aqueduct/generic_artifact.py | 4 +-- sdk/aqueduct/metric_artifact.py | 8 +++-- sdk/aqueduct/operators.py | 16 ++++++++-- sdk/aqueduct/param_artifact.py | 10 ++++-- sdk/aqueduct/table_artifact.py | 10 ++++-- 9 files changed, 115 insertions(+), 26 deletions(-) diff --git a/integration_tests/sdk/param_test.py b/integration_tests/sdk/param_test.py index 42e6ca54a..ebeacf535 100644 --- a/integration_tests/sdk/param_test.py +++ b/integration_tests/sdk/param_test.py @@ -2,6 +2,7 @@ import pytest +from aqueduct.error import InvalidUserArgumentException from constants import SENTIMENT_SQL_QUERY from utils import get_integration_name, run_flow_test from aqueduct import metric, op @@ -51,6 +52,28 @@ def test_basic_param_creation(client): assert kv_df.get().equals(pd.DataFrame(data=kv)) +def test_non_jsonable_parameter(sp_client): + with pytest.raises(InvalidUserArgumentException): + _ = sp_client.create_param(name="bad param", default=b"cant serialize me") + + param = sp_client.create_param(name="number", default=8) + param_doubled = double_number_input(param) + with pytest.raises(InvalidUserArgumentException): + _ = param_doubled.get(parameters={"number": b"cant serialize me"}) + + +def test_get_with_custom_parameter(sp_client): + param = sp_client.create_param(name="number", default=8) + assert param.get() == 8 + + param_doubled = double_number_input(param) + assert param_doubled.get(parameters={"number": 20}) == 40 + assert param_doubled.get() == 2 * 8 + + with pytest.raises(InvalidUserArgumentException): + param_doubled.get(parameters={"non-existant param": 10}) + + @op def append_row_to_df(df, row): """`row` is a list of values to append to the input dataframe.""" diff --git a/sdk/aqueduct/aqueduct_client.py b/sdk/aqueduct/aqueduct_client.py index 54b7f86b8..c2f08b177 100644 --- a/sdk/aqueduct/aqueduct_client.py +++ b/sdk/aqueduct/aqueduct_client.py @@ -22,7 +22,7 @@ from .integrations.salesforce_integration import SalesforceIntegration from .integrations.google_sheets_integration import GoogleSheetsIntegration from .integrations.s3_integration import S3Integration -from .operators import Operator, ParamSpec, OperatorSpec +from .operators import Operator, ParamSpec, OperatorSpec, serialize_parameter_value from .param_artifact import ParamArtifact from .utils import ( schedule_from_cron_string, @@ -128,13 +128,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA if default is None: raise InvalidUserArgumentException("Parameter default value cannot be None.") - # Check that the supplied value is JSON-able. - try: - serialized_default = str(json.dumps(default)) - except Exception as e: - raise InvalidUserArgumentException( - "Provided parameter must be able to be converted into a JSON object: %s" % str(e) - ) + val = serialize_parameter_value(name, default) operator_id = generate_uuid() output_artifact_id = generate_uuid() @@ -146,7 +140,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA id=operator_id, name=name, description=description, - spec=OperatorSpec(param=ParamSpec(val=serialized_default)), + spec=OperatorSpec(param=ParamSpec(val=val)), inputs=[], outputs=[output_artifact_id], ), diff --git a/sdk/aqueduct/check_artifact.py b/sdk/aqueduct/check_artifact.py index bb22c9e59..a3e2c02eb 100644 --- a/sdk/aqueduct/check_artifact.py +++ b/sdk/aqueduct/check_artifact.py @@ -2,10 +2,12 @@ import json import uuid +from typing import Optional, Dict, Any + from aqueduct.utils import get_description_for_check from aqueduct.api_client import APIClient -from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta +from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta, UpdateParametersDelta from aqueduct.error import AqueductError from aqueduct.generic_artifact import Artifact @@ -40,7 +42,7 @@ def __init__( self._dag = dag self._artifact_id = artifact_id - def get(self) -> bool: + def get(self, parameters: Optional[Dict[str, Any]] = None) -> bool: """Materializes a CheckArtifact into a boolean. Returns: @@ -58,7 +60,10 @@ def get(self) -> bool: SubgraphDAGDelta( artifact_ids=[self._artifact_id], include_load_operators=False, - ) + ), + UpdateParametersDelta( + parameters=parameters, + ), ], make_copy=True, ) diff --git a/sdk/aqueduct/dag.py b/sdk/aqueduct/dag.py index 373ede703..40e7a479f 100644 --- a/sdk/aqueduct/dag.py +++ b/sdk/aqueduct/dag.py @@ -1,6 +1,6 @@ import copy import uuid -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Any from abc import ABC, abstractmethod from pydantic import BaseModel @@ -8,11 +8,12 @@ InternalAqueductError, InvalidUserActionException, ArtifactNotFoundException, + InvalidUserArgumentException, ) from aqueduct.artifact import Artifact from aqueduct.enums import OperatorType, TriggerType -from aqueduct.operators import Operator, get_operator_type +from aqueduct.operators import Operator, get_operator_type, serialize_parameter_value class Schedule(BaseModel): @@ -190,7 +191,8 @@ def list_artifacts( return [artifact for artifact in self.artifacts.values()] - # DAG WRITES + ######################## DAG WRITES ############################# + def add_operator(self, op: Operator) -> None: self.add_operators([op]) @@ -203,6 +205,11 @@ def add_artifacts(self, artifacts: List[Artifact]) -> None: for artifact in artifacts: self.artifacts[str(artifact.id)] = artifact + def update_operator(self, op: Operator) -> None: + """Blind replace of an operator in the dag.""" + self.operators[str(op.id)] = op + self.operator_by_name[op.name] = op + def remove_operator( self, operator_id: uuid.UUID, @@ -430,6 +437,40 @@ def apply(self, dag: DAG) -> None: ) +class UpdateParametersDelta(DAGDelta): + """Updates the values of the given parameters in the DAG to the given values. No-ops if no parameters provided.""" + + def __init__( + self, + parameters: Optional[Dict[str, Any]], + ): + self.parameters = parameters + + def apply(self, dag: DAG) -> None: + if self.parameters is None: + return + + for param_name, new_val in self.parameters.items(): + param_op = dag.get_operator(with_name=param_name) + if param_op is None: + raise InvalidUserArgumentException( + "Parameter %s cannot be found, or is not utilized in the current computation." + % param_name + ) + if get_operator_type(param_op) != OperatorType.PARAM: + raise InvalidUserArgumentException( + "Parameter %s must refer to a parameter, but instead refers to a: %s" + % (param_name, get_operator_type(param_op)) + ) + + # Update the parameter value and update the dag. + assert param_op.spec.param # for mypy + param_op.spec.param.val = serialize_parameter_value( + param_name, self.parameters[param_name] + ) + dag.update_operator(param_op) + + def apply_deltas_to_dag(dag: DAG, deltas: List[DAGDelta], make_copy: bool = False) -> DAG: if make_copy: dag = copy.deepcopy(dag) diff --git a/sdk/aqueduct/generic_artifact.py b/sdk/aqueduct/generic_artifact.py index 8eacab6de..7b2a5ecf4 100644 --- a/sdk/aqueduct/generic_artifact.py +++ b/sdk/aqueduct/generic_artifact.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod import json -from typing import Any, Dict +from typing import Any, Dict, Optional import uuid from aqueduct.dag import DAG @@ -34,5 +34,5 @@ def describe(self) -> None: pass @abstractmethod - def get(self) -> Any: + def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any: pass diff --git a/sdk/aqueduct/metric_artifact.py b/sdk/aqueduct/metric_artifact.py index c93be23e1..e72bf6821 100644 --- a/sdk/aqueduct/metric_artifact.py +++ b/sdk/aqueduct/metric_artifact.py @@ -12,6 +12,7 @@ DAG, SubgraphDAGDelta, RemoveCheckOperatorDelta, + UpdateParametersDelta, ) from aqueduct.error import AqueductError @@ -56,7 +57,7 @@ def __init__( self._dag = dag self._artifact_id = artifact_id - def get(self) -> float: + def get(self, parameters: Optional[Dict[str, Any]] = None) -> float: """Materializes a MetricArtifact into its immediate float value. Returns: @@ -74,7 +75,10 @@ def get(self) -> float: SubgraphDAGDelta( artifact_ids=[self._artifact_id], include_load_operators=False, - ) + ), + UpdateParametersDelta( + parameters=parameters, + ), ], make_copy=True, ) diff --git a/sdk/aqueduct/operators.py b/sdk/aqueduct/operators.py index 573d73b9d..c691372db 100644 --- a/sdk/aqueduct/operators.py +++ b/sdk/aqueduct/operators.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Union +import json +from typing import List, Optional, Union, Any import uuid from pydantic import BaseModel @@ -15,7 +16,7 @@ LoadUpdateMode, CheckSeverity, ) -from aqueduct.error import AqueductError +from aqueduct.error import AqueductError, InvalidUserArgumentException from aqueduct.integrations.integration import IntegrationInfo @@ -186,3 +187,14 @@ def get_operator_type(operator: Operator) -> OperatorType: return OperatorType.PARAM else: raise AqueductError("Invalid operator type") + + +def serialize_parameter_value(name: str, val: Any) -> str: + """A parameter must be JSON serializable.""" + try: + return str(json.dumps(val)) + except Exception as e: + raise InvalidUserArgumentException( + "Provided parameter %s must be able to be converted into a JSON object: %s" + % (name, str(e)) + ) diff --git a/sdk/aqueduct/param_artifact.py b/sdk/aqueduct/param_artifact.py index 94ebd4b1e..905b86c2c 100644 --- a/sdk/aqueduct/param_artifact.py +++ b/sdk/aqueduct/param_artifact.py @@ -1,9 +1,10 @@ import json import uuid -from typing import Any +from typing import Any, Dict, Optional from aqueduct.api_client import APIClient from aqueduct.dag import DAG +from aqueduct.error import InvalidUserArgumentException from aqueduct.generic_artifact import Artifact @@ -20,7 +21,12 @@ def __init__( self._dag = dag self._artifact_id = artifact_id - def get(self) -> Any: + def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any: + if parameters is not None: + raise InvalidUserArgumentException( + "Parameters cannot be supplied to parameter artifacts." + ) + _ = self._dag.must_get_artifact(self._artifact_id) param_op = self._dag.must_get_operator(with_output_artifact_id=self._artifact_id) assert param_op.spec.param is not None, "Artifact is not a parameter." diff --git a/sdk/aqueduct/table_artifact.py b/sdk/aqueduct/table_artifact.py index 3746b2d4d..d4e969419 100644 --- a/sdk/aqueduct/table_artifact.py +++ b/sdk/aqueduct/table_artifact.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import Callable, Any +from typing import Callable, Any, Optional, Dict import uuid import pandas as pd @@ -15,6 +15,7 @@ AddOrReplaceOperatorDelta, SubgraphDAGDelta, RemoveCheckOperatorDelta, + UpdateParametersDelta, ) from aqueduct.enums import OperatorType, FunctionType, FunctionGranularity from aqueduct.error import ( @@ -73,7 +74,7 @@ def __init__( self._dag = dag self._artifact_id = artifact_id - def get(self) -> pd.DataFrame: + def get(self, parameters: Optional[Dict[str, Any]] = None) -> pd.DataFrame: """Materializes TableArtifact into an actual dataframe. Returns: @@ -91,7 +92,10 @@ def get(self) -> pd.DataFrame: SubgraphDAGDelta( artifact_ids=[self._artifact_id], include_load_operators=False, - ) + ), + UpdateParametersDelta( + parameters=parameters, + ), ], make_copy=True, ) From 9b3b865a4869647fedcf5f9b6cef7b101861d237 Mon Sep 17 00:00:00 2001 From: Kenneth Xu Date: Mon, 13 Jun 2022 12:20:36 -0700 Subject: [PATCH 2/2] done and tested --- integration_tests/sdk/param_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integration_tests/sdk/param_test.py b/integration_tests/sdk/param_test.py index ebeacf535..80e07f4c3 100644 --- a/integration_tests/sdk/param_test.py +++ b/integration_tests/sdk/param_test.py @@ -52,18 +52,18 @@ def test_basic_param_creation(client): assert kv_df.get().equals(pd.DataFrame(data=kv)) -def test_non_jsonable_parameter(sp_client): +def test_non_jsonable_parameter(client): with pytest.raises(InvalidUserArgumentException): - _ = sp_client.create_param(name="bad param", default=b"cant serialize me") + _ = client.create_param(name="bad param", default=b"cant serialize me") - param = sp_client.create_param(name="number", default=8) + param = client.create_param(name="number", default=8) param_doubled = double_number_input(param) with pytest.raises(InvalidUserArgumentException): _ = param_doubled.get(parameters={"number": b"cant serialize me"}) -def test_get_with_custom_parameter(sp_client): - param = sp_client.create_param(name="number", default=8) +def test_get_with_custom_parameter(client): + param = client.create_param(name="number", default=8) assert param.get() == 8 param_doubled = double_number_input(param)