Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to realize artifacts with custom parameters #96

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions integration_tests/sdk/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from aqueduct.error import InvalidUserArgumentException
from constants import SENTIMENT_SQL_QUERY
from utils import get_integration_name, run_flow_test
from aqueduct import metric, op
Expand Down Expand Up @@ -51,6 +52,28 @@ def test_basic_param_creation(client):
assert kv_df.get().equals(pd.DataFrame(data=kv))


def test_non_jsonable_parameter(client):
    """Values that cannot be JSON-serialized are rejected as parameter values.

    Covers both places a parameter value can be supplied: as the default at
    creation time, and as an override passed to `get()`.
    """
    unserializable = b"cant serialize me"

    # Rejected as a default when the parameter is created.
    with pytest.raises(InvalidUserArgumentException):
        client.create_param(name="bad param", default=unserializable)

    # Rejected as an override for an otherwise valid parameter.
    number_param = client.create_param(name="number", default=8)
    doubled = double_number_input(number_param)
    with pytest.raises(InvalidUserArgumentException):
        doubled.get(parameters={"number": unserializable})


def test_get_with_custom_parameter(client):
    """Exercise `get()` with and without a `parameters` override."""
    default_value = 8
    number_param = client.create_param(name="number", default=default_value)
    assert number_param.get() == default_value

    doubled = double_number_input(number_param)

    # An override is used for the call it is passed to ...
    assert doubled.get(parameters={"number": 20}) == 40
    # ... and a later call without overrides is computed from the default again.
    assert doubled.get() == 2 * default_value

    # Supplying a name that matches no parameter is reported as a user error.
    with pytest.raises(InvalidUserArgumentException):
        doubled.get(parameters={"non-existant param": 10})


@op
def append_row_to_df(df, row):
"""`row` is a list of values to append to the input dataframe."""
Expand Down
12 changes: 3 additions & 9 deletions sdk/aqueduct/aqueduct_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .integrations.salesforce_integration import SalesforceIntegration
from .integrations.google_sheets_integration import GoogleSheetsIntegration
from .integrations.s3_integration import S3Integration
from .operators import Operator, ParamSpec, OperatorSpec
from .operators import Operator, ParamSpec, OperatorSpec, serialize_parameter_value
from .param_artifact import ParamArtifact
from .utils import (
schedule_from_cron_string,
Expand Down Expand Up @@ -128,13 +128,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA
if default is None:
raise InvalidUserArgumentException("Parameter default value cannot be None.")

# Check that the supplied value is JSON-able.
try:
serialized_default = str(json.dumps(default))
except Exception as e:
raise InvalidUserArgumentException(
"Provided parameter must be able to be converted into a JSON object: %s" % str(e)
)
val = serialize_parameter_value(name, default)

operator_id = generate_uuid()
output_artifact_id = generate_uuid()
Expand All @@ -146,7 +140,7 @@ def create_param(self, name: str, default: Any, description: str = "") -> ParamA
id=operator_id,
name=name,
description=description,
spec=OperatorSpec(param=ParamSpec(val=serialized_default)),
spec=OperatorSpec(param=ParamSpec(val=val)),
inputs=[],
outputs=[output_artifact_id],
),
Expand Down
11 changes: 8 additions & 3 deletions sdk/aqueduct/check_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import json

import uuid
from typing import Optional, Dict, Any

from aqueduct.utils import get_description_for_check

from aqueduct.api_client import APIClient
from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta
from aqueduct.dag import DAG, apply_deltas_to_dag, SubgraphDAGDelta, UpdateParametersDelta
from aqueduct.error import AqueductError

from aqueduct.generic_artifact import Artifact
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> bool:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> bool:
"""Materializes a CheckArtifact into a boolean.

Returns:
Expand All @@ -58,7 +60,10 @@ def get(self) -> bool:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down
47 changes: 44 additions & 3 deletions sdk/aqueduct/dag.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import copy
import uuid
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
from abc import ABC, abstractmethod

from pydantic import BaseModel
from aqueduct.error import (
InternalAqueductError,
InvalidUserActionException,
ArtifactNotFoundException,
InvalidUserArgumentException,
)

from aqueduct.artifact import Artifact
from aqueduct.enums import OperatorType, TriggerType
from aqueduct.operators import Operator, get_operator_type
from aqueduct.operators import Operator, get_operator_type, serialize_parameter_value


class Schedule(BaseModel):
Expand Down Expand Up @@ -190,7 +191,8 @@ def list_artifacts(

return [artifact for artifact in self.artifacts.values()]

# DAG WRITES
######################## DAG WRITES #############################

def add_operator(self, op: Operator) -> None:
    """Add a single operator by delegating to `add_operators` with a one-element list."""
    self.add_operators([op])

Expand All @@ -203,6 +205,11 @@ def add_artifacts(self, artifacts: List[Artifact]) -> None:
for artifact in artifacts:
self.artifacts[str(artifact.id)] = artifact

def update_operator(self, op: Operator) -> None:
    """Overwrite the dag's entries for `op` without any validation.

    Both lookup tables are written: `operators` under the stringified
    operator id, and `operator_by_name` under the operator's name.
    """
    op_key = str(op.id)
    self.operators[op_key] = op
    self.operator_by_name[op.name] = op

def remove_operator(
self,
operator_id: uuid.UUID,
Expand Down Expand Up @@ -430,6 +437,40 @@ def apply(self, dag: DAG) -> None:
)


class UpdateParametersDelta(DAGDelta):
    """Updates the values of the given parameters in the DAG to the given values.

    `parameters` maps a parameter operator's name to its new value.
    Applying the delta is a no-op if `parameters` is None.
    """

    def __init__(
        self,
        parameters: Optional[Dict[str, Any]],
    ):
        self.parameters = parameters

    def apply(self, dag: DAG) -> None:
        """Overwrite the value of every named parameter operator in `dag`.

        All entries are validated and serialized before any operator is
        modified, so a bad entry cannot leave the dag partially updated.

        Raises:
            InvalidUserArgumentException: if a name does not resolve to an
                operator in the dag, resolves to an operator that is not a
                parameter, or (via `serialize_parameter_value`) maps to a
                value that cannot be serialized.
        """
        if self.parameters is None:
            return

        # Phase 1: validate and serialize every entry without touching the dag.
        pending_updates = []
        for param_name, new_val in self.parameters.items():
            param_op = dag.get_operator(with_name=param_name)
            if param_op is None:
                raise InvalidUserArgumentException(
                    "Parameter %s cannot be found, or is not utilized in the current computation."
                    % param_name
                )

            op_type = get_operator_type(param_op)
            if op_type != OperatorType.PARAM:
                raise InvalidUserArgumentException(
                    "Parameter %s must refer to a parameter, but instead refers to a: %s"
                    % (param_name, op_type)
                )

            pending_updates.append((param_op, serialize_parameter_value(param_name, new_val)))

        # Phase 2: every entry is known to be valid, so apply the new values.
        for param_op, serialized_val in pending_updates:
            assert param_op.spec.param  # for mypy
            param_op.spec.param.val = serialized_val
            dag.update_operator(param_op)


def apply_deltas_to_dag(dag: DAG, deltas: List[DAGDelta], make_copy: bool = False) -> DAG:
if make_copy:
dag = copy.deepcopy(dag)
Expand Down
4 changes: 2 additions & 2 deletions sdk/aqueduct/generic_artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import json
from typing import Any, Dict
from typing import Any, Dict, Optional
import uuid
from aqueduct.dag import DAG

Expand Down Expand Up @@ -34,5 +34,5 @@ def describe(self) -> None:
pass

@abstractmethod
def get(self) -> Any:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any:
pass
8 changes: 6 additions & 2 deletions sdk/aqueduct/metric_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DAG,
SubgraphDAGDelta,
RemoveCheckOperatorDelta,
UpdateParametersDelta,
)
from aqueduct.error import AqueductError

Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> float:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> float:
"""Materializes a MetricArtifact into its immediate float value.

Returns:
Expand All @@ -74,7 +75,10 @@ def get(self) -> float:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down
16 changes: 14 additions & 2 deletions sdk/aqueduct/operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Union
import json
from typing import List, Optional, Union, Any
import uuid

from pydantic import BaseModel
Expand All @@ -15,7 +16,7 @@
LoadUpdateMode,
CheckSeverity,
)
from aqueduct.error import AqueductError
from aqueduct.error import AqueductError, InvalidUserArgumentException
from aqueduct.integrations.integration import IntegrationInfo


Expand Down Expand Up @@ -186,3 +187,14 @@ def get_operator_type(operator: Operator) -> OperatorType:
return OperatorType.PARAM
else:
raise AqueductError("Invalid operator type")


def serialize_parameter_value(name: str, val: Any) -> str:
    """Serialize a parameter value into its JSON string representation.

    A parameter must be JSON serializable.

    Args:
        name: Name of the parameter; only used in the error message.
        val: The value to serialize.

    Returns:
        The output of `json.dumps(val)`.

    Raises:
        InvalidUserArgumentException: if `json.dumps` fails on `val`. The
            underlying error is attached as the exception's cause.
    """
    try:
        # json.dumps already returns a str, so no extra conversion is needed.
        return json.dumps(val)
    except Exception as e:
        raise InvalidUserArgumentException(
            "Provided parameter %s must be able to be converted into a JSON object: %s"
            % (name, str(e))
        ) from e
10 changes: 8 additions & 2 deletions sdk/aqueduct/param_artifact.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import uuid
from typing import Any
from typing import Any, Dict, Optional

from aqueduct.api_client import APIClient
from aqueduct.dag import DAG
from aqueduct.error import InvalidUserArgumentException
from aqueduct.generic_artifact import Artifact


Expand All @@ -20,7 +21,12 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> Any:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> Any:
if parameters is not None:
raise InvalidUserArgumentException(
"Parameters cannot be supplied to parameter artifacts."
)

_ = self._dag.must_get_artifact(self._artifact_id)
param_op = self._dag.must_get_operator(with_output_artifact_id=self._artifact_id)
assert param_op.spec.param is not None, "Artifact is not a parameter."
Expand Down
10 changes: 7 additions & 3 deletions sdk/aqueduct/table_artifact.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Callable, Any
from typing import Callable, Any, Optional, Dict
import uuid

import pandas as pd
Expand All @@ -15,6 +15,7 @@
AddOrReplaceOperatorDelta,
SubgraphDAGDelta,
RemoveCheckOperatorDelta,
UpdateParametersDelta,
)
from aqueduct.enums import OperatorType, FunctionType, FunctionGranularity
from aqueduct.error import (
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
self._dag = dag
self._artifact_id = artifact_id

def get(self) -> pd.DataFrame:
def get(self, parameters: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
"""Materializes TableArtifact into an actual dataframe.

Returns:
Expand All @@ -91,7 +92,10 @@ def get(self) -> pd.DataFrame:
SubgraphDAGDelta(
artifact_ids=[self._artifact_id],
include_load_operators=False,
)
),
UpdateParametersDelta(
parameters=parameters,
),
],
make_copy=True,
)
Expand Down