Skip to content

Commit

Permalink
Allow users to realize artifacts with custom parameters (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
kenxu95 authored Jun 13, 2022
1 parent e6bbc9b commit 27c6e89
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 26 deletions.
23 changes: 23 additions & 0 deletions integration_tests/sdk/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from aqueduct.error import InvalidUserArgumentException
from constants import SENTIMENT_SQL_QUERY
from utils import get_integration_name, run_flow_test
from aqueduct import metric, op
Expand Down Expand Up @@ -51,6 +52,28 @@ def test_basic_param_creation(client):
assert kv_df.get().equals(pd.DataFrame(data=kv))


def test_non_jsonable_parameter(client):
with pytest.raises(InvalidUserArgumentException):
_ = client.create_param(name="bad param", default=b"cant serialize me")

param = client.create_param(name="number", default=8)
param_doubled = double_number_input(param)
with pytest.raises(InvalidUserArgumentException):
_ = param_doubled.get(parameters={"number": b"cant serialize me"})


def test_get_with_custom_parameter(client):
param = client.create_param(name="number", default=8)
assert param.get() == 8

param_doubled = double_number_input(param)
assert param_doubled.get(parameters={"number": 20}) == 40
assert param_doubled.get() == 2 * 8

with pytest.raises(InvalidUserArgumentException):
param_doubled.get(parameters={"non-existant param": 10})


@op
def append_row_to_df(df, row):
"""`row` is a list of values to append to the input dataframe."""
Expand Down
12 changes: 3 additions & 9 deletions sdk/aqueduct/aqueduct_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .integrations.salesforce_integration import SalesforceIntegration
from .integrations.google_sheets_integration import GoogleSheetsIntegration
from .integrations.s3_integration import S3Integration
from .operators import Operator, ParamSpec, OperatorSpec
from .operators import Operator, ParamSpec, OperatorSpec, serialize_parameter_value
from .param_artifact import ParamArtifact
from .utils import (
schedule_from_cron_string,
Expand Down Expand Up @@ -128,13 +128,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA
if default is None:
raise InvalidUserArgumentException("Parameter default value cannot be None.")

# Check that the supplied value is JSON-able.
try:
serialized_default = str(json.dumps(default))
except Exception as e:
raise InvalidUserArgumentException(
"Provided parameter must be able to be converted into a JSON object: %s" % str(e)
)
val = serialize_parameter_value(name, default)

operator_id = generate_uuid()
output_artifact_id = generate_uuid()
Expand All @@ -146,7 +140,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA
id=operator_id,
name=name,
description=description,
spec=OperatorSpec(param=ParamSpec(val=serialized_default)),
spec=OperatorSpec(param=ParamSpec(val=val)),
inputs=[],
outputs=[output_artifact_id],
),
Expand Down
11 changes: 8 additions & 3 deletions sdk/aqueduct/check_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import json

import uuid
from typing import Optional, Dict, Any

from aqueduct.utils import get_description_for_check

from aqueduct.api_client import APIClient
from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta
from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta, UpdateParametersDelta
from aqueduct.error import AqueductError

from aqueduct.generic_artifact import Artifact
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> bool:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> bool:
"""Materializes a CheckArtifact into a boolean.
Returns:
Expand All @@ -58,7 +60,10 @@ def get(self) -> bool:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down
47 changes: 44 additions & 3 deletions sdk/aqueduct/dag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import copy
import uuid
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
from abc import ABC, abstractmethod

from pydantic import BaseModel
from aqueduct.error import (
InternalAqueductError,
InvalidUserActionException,
ArtifactNotFoundException,
InvalidUserArgumentException,
)

from aqueduct.artifact import Artifact
from aqueduct.enums import OperatorType, TriggerType
from aqueduct.operators import Operator, get_operator_type
from aqueduct.operators import Operator, get_operator_type, serialize_parameter_value


class Schedule(BaseModel):
Expand Down Expand Up @@ -190,7 +191,8 @@ def list_artifacts(

return [artifact for artifact in self.artifacts.values()]

# DAG WRITES
######################## DAG WRITES #############################

def add_operator(self, op: Operator) -> None:
self.add_operators([op])

Expand All @@ -203,6 +205,11 @@ def add_artifacts(self, artifacts: List[Artifact]) -> None:
for artifact in artifacts:
self.artifacts[str(artifact.id)] = artifact

def update_operator(self, op: Operator) -> None:
"""Blind replace of an operator in the dag."""
self.operators[str(op.id)] = op
self.operator_by_name[op.name] = op

def remove_operator(
self,
operator_id: uuid.UUID,
Expand Down Expand Up @@ -430,6 +437,40 @@ def apply(self, dag: DAG) -> None:
)


class UpdateParametersDelta(DAGDelta):
"""Updates the values of the given parameters in the DAG to the given values. No-ops if no parameters provided."""

def __init__(
self,
parameters: Optional[Dict[str, Any]],
):
self.parameters = parameters

def apply(self, dag: DAG) -> None:
if self.parameters is None:
return

for param_name, new_val in self.parameters.items():
param_op = dag.get_operator(with_name=param_name)
if param_op is None:
raise InvalidUserArgumentException(
"Parameter %s cannot be found, or is not utilized in the current computation."
% param_name
)
if get_operator_type(param_op) != OperatorType.PARAM:
raise InvalidUserArgumentException(
"Parameter %s must refer to a parameter, but instead refers to a: %s"
% (param_name, get_operator_type(param_op))
)

# Update the parameter value and update the dag.
assert param_op.spec.param # for mypy
param_op.spec.param.val = serialize_parameter_value(
param_name, self.parameters[param_name]
)
dag.update_operator(param_op)


def apply_deltas_to_dag(dag: DAG, deltas: List[DAGDelta], make_copy: bool = False) -> DAG:
if make_copy:
dag = copy.deepcopy(dag)
Expand Down
4 changes: 2 additions & 2 deletions sdk/aqueduct/generic_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import json
from typing import Any, Dict
from typing import Any, Dict, Optional
import uuid
from aqueduct.dag import DAG

Expand Down Expand Up @@ -34,5 +34,5 @@ def describe(self) -> None:
pass

@abstractmethod
def get(self) -> Any:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any:
pass
8 changes: 6 additions & 2 deletions sdk/aqueduct/metric_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DAG,
SubgraphDAGDelta,
RemoveCheckOperatorDelta,
UpdateParametersDelta,
)
from aqueduct.error import AqueductError

Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> float:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> float:
"""Materializes a MetricArtifact into its immediate float value.
Returns:
Expand All @@ -74,7 +75,10 @@ def get(self) -> float:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down
16 changes: 14 additions & 2 deletions sdk/aqueduct/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Union
import json
from typing import List, Optional, Union, Any
import uuid

from pydantic import BaseModel
Expand All @@ -15,7 +16,7 @@
LoadUpdateMode,
CheckSeverity,
)
from aqueduct.error import AqueductError
from aqueduct.error import AqueductError, InvalidUserArgumentException
from aqueduct.integrations.integration import IntegrationInfo


Expand Down Expand Up @@ -186,3 +187,14 @@ def get_operator_type(operator: Operator) -> OperatorType:
return OperatorType.PARAM
else:
raise AqueductError("Invalid operator type")


def serialize_parameter_value(name: str, val: Any) -> str:
"""A parameter must be JSON serializable."""
try:
return str(json.dumps(val))
except Exception as e:
raise InvalidUserArgumentException(
"Provided parameter %s must be able to be converted into a JSON object: %s"
% (name, str(e))
)
10 changes: 8 additions & 2 deletions sdk/aqueduct/param_artifact.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import uuid
from typing import Any
from typing import Any, Dict, Optional

from aqueduct.api_client import APIClient
from aqueduct.dag import DAG
from aqueduct.error import InvalidUserArgumentException
from aqueduct.generic_artifact import Artifact


Expand All @@ -20,7 +21,12 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> Any:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any:
if parameters is not None:
raise InvalidUserArgumentException(
"Parameters cannot be supplied to parameter artifacts."
)

_ = self._dag.must_get_artifact(self._artifact_id)
param_op = self._dag.must_get_operator(with_output_artifact_id=self._artifact_id)
assert param_op.spec.param is not None, "Artifact is not a parameter."
Expand Down
10 changes: 7 additions & 3 deletions sdk/aqueduct/table_artifact.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Callable, Any
from typing import Callable, Any, Optional, Dict
import uuid

import pandas as pd
Expand All @@ -15,6 +15,7 @@
AddOrReplaceOperatorDelta,
SubgraphDAGDelta,
RemoveCheckOperatorDelta,
UpdateParametersDelta,
)
from aqueduct.enums import OperatorType, FunctionType, FunctionGranularity
from aqueduct.error import (
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> pd.DataFrame:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
"""Materializes TableArtifact into an actual dataframe.
Returns:
Expand All @@ -91,7 +92,10 @@ def get(self) -> pd.DataFrame:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down

0 comments on commit 27c6e89

Please sign in to comment.