Refactor Spark Python Executor #1231

Merged
merged 10 commits into from
Apr 24, 2023
138 changes: 115 additions & 23 deletions src/python/aqueduct_executor/operators/connectors/data/execute.py
@@ -48,11 +48,49 @@ def run(spec: Spec) -> None:
Arguments:
- spec: The spec provided for this operator.
"""
return execute_data_spec(
spec=spec,
read_artifacts_func=utils.read_artifacts,
write_artifact_func=utils.write_artifact,
setup_connector_func=setup_connector,
is_spark=False,
)
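
For readers following the refactor: the point of injecting these functions is that a Spark-specific entrypoint can reuse the same execution flow with its own storage helpers, forwarding the session through **kwargs. A rough sketch of what such a caller might look like — run_spark, the module paths, and the read_spark_artifacts/write_spark_artifact helper names are assumptions based on the discussion below, not the actual Spark executor code:

# Hypothetical Spark-side counterpart of run(); helper and module names are assumed.
from pyspark.sql import SparkSession  # pyspark is imported only in this Spark-only module

from aqueduct_executor.operators.connectors.data import utils  # assumed home of the Spark helpers
from aqueduct_executor.operators.connectors.data.execute import execute_data_spec, setup_connector
from aqueduct_executor.operators.connectors.data.spec import Spec  # assumed import path


def run_spark(spec: Spec) -> None:
    # Attach to (or create) the Spark session and forward it through **kwargs as
    # spark_session_obj, the only kwarg execute_data_spec expects.
    spark_session = SparkSession.builder.getOrCreate()
    execute_data_spec(
        spec=spec,
        read_artifacts_func=utils.read_spark_artifacts,  # assumed Spark variant
        write_artifact_func=utils.write_spark_artifact,  # assumed Spark variant
        setup_connector_func=setup_connector,  # the real code may use a Spark-aware variant
        is_spark=True,
        spark_session_obj=spark_session,
    )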


def execute_data_spec(
spec: Spec,
read_artifacts_func: Any,
write_artifact_func: Any,
Contributor

If we are already passing function objects, does it make sense to pass different function objects based on whether it's Spark or not? That way we wouldn't even need to pass is_spark and the other Spark-specific arguments.

For example, could we do something like

if is_spark:
 run_helper(spec, read_artifact_func=utils.read_spark_artifacts, write_artifact_func=utils.write_spark_artifacts, ...)
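
Spelled out a little further (run_helper and the Spark utils names are placeholders taken from the comment above, not the actual API), the suggestion is that each entrypoint injects helpers that are already bound to their environment, so the shared code never needs an is_spark flag:

# Placeholder sketch of dispatching purely via injected functions.
from functools import partial


def run(spec: Spec) -> None:
    run_helper(
        spec,
        read_artifact_func=utils.read_artifacts,
        write_artifact_func=utils.write_artifact,
        setup_connector_func=setup_connector,
    )


def run_spark(spec: Spec, spark_session: "SparkSession") -> None:
    # Bind the session up front so the shared helper can call these uniformly.
    run_helper(
        spec,
        read_artifact_func=partial(utils.read_spark_artifacts, spark_session_obj=spark_session),
        write_artifact_func=partial(utils.write_spark_artifacts, spark_session_obj=spark_session),
        setup_connector_func=setup_connector,
    )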

setup_connector_func: Any,
is_spark: bool,
Contributor

If we have to pass this arg around, does it make sense to explicitly pass a single Optional[spark.Session] object to decide whether Spark is enabled? Also, I'd like to remove **kwargs, since it's not clear what to expect and how it's used. This pattern is more useful in cases like decorators where we support arbitrary inputs, which is not the case here.

Contributor Author

This is the correct pattern; however, it requires that we import pyspark.sql in the regular code path, which we want to avoid. The **kwargs pattern was chosen to keep that import out of non-Spark environments.
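
For concreteness, the explicit-session alternative would look roughly like the sketch below (import paths assumed); the module-level pyspark import is exactly what the **kwargs approach keeps out of non-Spark environments:

# Sketch of the rejected alternative: typing the session explicitly means every
# environment that imports this module also needs pyspark installed.
from typing import Any, Optional

from pyspark.sql import SparkSession  # required at import time, even off Spark

from aqueduct_executor.operators.connectors.data.spec import Spec  # assumed import path


def execute_data_spec(
    spec: Spec,
    read_artifacts_func: Any,
    write_artifact_func: Any,
    setup_connector_func: Any,
    spark_session: Optional[SparkSession] = None,  # None means "not running on Spark"
) -> None:
    ...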

**kwargs: Any,
) -> None:
"""
This function executes the provided spec. If run in a Spark environment, it uses
the Spark-specific utils functions to read from/write to the storage layer and to set up connectors.
The only kwarg we expect is spark_session_obj.
Contributor

nit: Can we use the same docstring format as we do in the SDK (see the ones in client.py)?


Args:
spec: The spec provided for this operator.
read_artifacts_func: function used to read artifacts from the storage layer
write_artifact_func: function used to write artifacts to the storage layer
setup_connector_func: function used to set up the connectors
is_spark: Whether or not we are running in a Spark env.
"""
storage = parse_storage(spec.storage_config)
exec_state = ExecutionState(user_logs=Logs())

try:
_execute(spec, storage, exec_state)
_execute(
spec,
storage,
exec_state,
read_artifacts_func,
write_artifact_func,
setup_connector_func,
is_spark,
**kwargs,
)
# Write operator execution metadata
# Each decorator may set exec_state.status to FAILED, but if none of them did, then we are
# certain that the operator succeeded.
@@ -86,23 +124,46 @@ def run(spec: Spec) -> None:
sys.exit(1)


def _execute(spec: Spec, storage: Storage, exec_state: ExecutionState) -> None:
def _execute(
spec: Spec,
storage: Storage,
exec_state: ExecutionState,
read_artifacts_func: Any,
write_artifact_func: Any,
setup_connector_func: Any,
is_spark: bool,
**kwargs: Any,
) -> None:
if spec.type == JobType.DELETESAVEDOBJECTS:
run_delete_saved_objects(spec, storage, exec_state)

# Because constructing certain connectors (eg. Postgres) can also involve authentication,
# we do both in `run_authenticate()`, and give a more helpful error message on failure.
elif spec.type == JobType.AUTHENTICATE:
run_authenticate(spec, exec_state, is_demo=(spec.name == AQUEDUCT_DEMO_NAME))
run_authenticate(
spec,
exec_state,
is_demo=(spec.name == AQUEDUCT_DEMO_NAME),
setup_connector_func=setup_connector_func,
)

else:
op = setup_connector(spec.connector_name, spec.connector_config)
op = setup_connector_func(spec.connector_name, spec.connector_config)
if spec.type == JobType.EXTRACT:
run_extract(spec, op, storage, exec_state)
run_extract(
spec,
op,
storage,
exec_state,
read_artifacts_func,
write_artifact_func,
is_spark,
**kwargs,
)
elif spec.type == JobType.LOADTABLE:
run_load_table(spec, op, storage)
run_load_table(spec, op, storage, is_spark)
elif spec.type == JobType.LOAD:
run_load(spec, op, storage, exec_state)
run_load(spec, op, storage, exec_state, read_artifacts_func, is_spark, **kwargs)
elif spec.type == JobType.DISCOVER:
run_discover(spec, op, storage)
else:
@@ -113,19 +174,27 @@ def run_authenticate(
spec: AuthenticateSpec,
exec_state: ExecutionState,
is_demo: bool,
setup_connector_func: Any,
) -> None:
@exec_state.user_fn_redirected(
failure_tip=TIP_DEMO_CONNECTION if is_demo else TIP_INTEGRATION_CONNECTION
)
def _authenticate() -> None:
op = setup_connector(spec.connector_name, spec.connector_config)
op = setup_connector_func(spec.connector_name, spec.connector_config)
op.authenticate()

_authenticate()


def run_extract(
spec: ExtractSpec, op: connector.DataConnector, storage: Storage, exec_state: ExecutionState
spec: ExtractSpec,
op: connector.DataConnector,
storage: Storage,
exec_state: ExecutionState,
read_artifacts_func: Any,
write_artifact_func: Any,
is_spark: bool,
**kwargs: Any,
) -> None:
extract_params = spec.parameters

@@ -134,10 +203,11 @@ def run_extract(
if isinstance(extract_params, extract.RelationalParams) or isinstance(
extract_params, extract.MongoDBParams
):
input_vals, _, _ = utils.read_artifacts(
storage,
spec.input_content_paths,
spec.input_metadata_paths,
input_vals, _, _ = read_artifacts_func(
storage=storage,
input_paths=spec.input_content_paths,
input_metadata_paths=spec.input_metadata_paths,
**kwargs,
)
assert all(
isinstance(param_val, str) for param_val in input_vals
@@ -146,7 +216,10 @@ def run_extract(

@exec_state.user_fn_redirected(failure_tip=TIP_EXTRACT)
def _extract() -> Any:
return op.extract(spec.parameters)
if is_spark:
return op.extract_spark(spec.parameters, **kwargs) # type: ignore
else:
return op.extract(spec.parameters)

output = _extract()

@@ -160,14 +233,15 @@ def _extract() -> Any:
output_artifact_type = ArtifactType.TUPLE

if exec_state.status != ExecutionStatus.FAILED:
utils.write_artifact(
write_artifact_func(
storage,
output_artifact_type,
derived_from_bson,
spec.output_content_path,
spec.output_metadata_path,
output,
system_metadata={},
**kwargs,
)


@@ -181,26 +255,44 @@ def run_delete_saved_objects(spec: Spec, storage: Storage, exec_state: Execution


def run_load(
spec: LoadSpec, op: connector.DataConnector, storage: Storage, exec_state: ExecutionState
spec: LoadSpec,
op: connector.DataConnector,
storage: Storage,
exec_state: ExecutionState,
read_artifacts_func: Any,
is_spark: bool,
**kwargs: Any,
) -> None:
inputs, input_types, _ = utils.read_artifacts(
storage,
[spec.input_content_path],
[spec.input_metadata_path],
inputs, input_types, _ = read_artifacts_func(
storage=storage,
input_paths=[spec.input_content_path],
input_metadata_paths=[spec.input_metadata_path],
**kwargs,
)
if len(inputs) != 1:
raise Exception("Expected 1 input artifact, but got %d" % len(inputs))

@exec_state.user_fn_redirected(failure_tip=TIP_LOAD)
def _load() -> None:
op.load(spec.parameters, inputs[0], input_types[0])
if is_spark:
op.load_spark(spec.parameters, inputs[0], input_types[0]) # type: ignore
else:
op.load(spec.parameters, inputs[0], input_types[0])

_load()


def run_load_table(spec: LoadTableSpec, op: connector.DataConnector, storage: Storage) -> None:
def run_load_table(
spec: LoadTableSpec,
op: connector.DataConnector,
storage: Storage,
is_spark: bool,
) -> None:
df = utils._read_csv(storage.get(spec.csv))
op.load(spec.load_parameters.parameters, df, ArtifactType.TABLE)
if is_spark:
op.load_spark(spec.load_parameters.parameters, df, ArtifactType.TABLE) # type: ignore
else:
op.load(spec.load_parameters.parameters, df, ArtifactType.TABLE)
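
The # type: ignore comments on the *_spark calls are presumably needed because extract_spark and load_spark are not declared on the base connector.DataConnector interface; only Spark-capable connectors provide them. The rough shape implied by the call sites (parameter and return types are assumptions) is:

# Hypothetical connector shape implied by the call sites above; not actual repo code.
from typing import Any


class SparkCapableConnector:
    def extract(self, params: Any) -> Any:
        """Regular, non-Spark extraction path."""
        raise NotImplementedError

    def extract_spark(self, params: Any, spark_session_obj: Any) -> Any:
        """Spark extraction path; receives the session forwarded through **kwargs."""
        raise NotImplementedError

    def load(self, params: Any, df: Any, artifact_type: Any) -> None:
        """Regular, non-Spark load path."""
        raise NotImplementedError

    def load_spark(self, params: Any, df: Any, artifact_type: Any) -> None:
        """Spark load path; same call signature as load()."""
        raise NotImplementedError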


def run_discover(spec: DiscoverSpec, op: connector.DataConnector, storage: Storage) -> None:
77 changes: 58 additions & 19 deletions src/python/aqueduct_executor/operators/function_executor/execute.py
@@ -142,6 +142,7 @@ def _invoke() -> Any:
def _validate_result_count_and_infer_type(
spec: FunctionSpec,
results: List[Any],
infer_type_func: Any,
) -> List[ArtifactType]:
"""
Validates that the expected number of results were returned by the Function
@@ -164,27 +165,30 @@ def _validate_result_count_and_infer_type(
% (len(spec.output_content_paths), len(results)),
)

return [infer_artifact_type(res) for res in results]
return [infer_type_func(res) for res in results]


def _write_artifacts(
write_artifact_func: Any,
results: Any,
result_types: List[ArtifactType],
derived_from_bson: bool,
output_content_paths: List[str],
output_metadata_paths: List[str],
system_metadata: Any,
storage: Storage,
**kwargs: Any,
) -> None:
for i, result in enumerate(results):
utils.write_artifact(
write_artifact_func(
storage,
result_types[i],
derived_from_bson,
output_content_paths[i],
output_metadata_paths[i],
result,
system_metadata=system_metadata,
**kwargs,
)


@@ -239,6 +243,32 @@ def run(spec: FunctionSpec) -> None:
"""
Executes a function operator.
"""
execute_function_spec(
spec=spec,
read_artifacts_func=utils.read_artifacts,
write_artifact_func=utils.write_artifact,
infer_type_func=infer_artifact_type,
)
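
Mirroring the data executor, a Spark entrypoint would presumably call execute_function_spec with its own helpers. A minimal sketch with assumed names (run_spark and spark_utils are not part of this diff):

# Hypothetical Spark-side counterpart of run(); spark_utils stands in for the
# assumed module that exposes Spark-aware read/write/infer helpers.
from pyspark.sql import SparkSession


def run_spark(spec: FunctionSpec) -> None:
    spark_session = SparkSession.builder.getOrCreate()
    execute_function_spec(
        spec=spec,
        read_artifacts_func=spark_utils.read_artifacts,    # assumed Spark variant
        write_artifact_func=spark_utils.write_artifact,    # assumed Spark variant
        infer_type_func=spark_utils.infer_artifact_type,   # assumed Spark variant
        spark_session_obj=spark_session,  # the only kwarg the docstring expects
    )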


def execute_function_spec(
spec: FunctionSpec,
read_artifacts_func: Any,
write_artifact_func: Any,
infer_type_func: Any,
**kwargs: Any,
) -> None:
"""
Executes a function operator. If run in a Spark environment, it uses the Spark-specific utils
functions to read from/write to the storage layer and to infer the types of artifacts.
The only kwarg we expect is spark_session_obj.

Args:
spec: The spec provided for this operator.
read_artifacts_func: function used to read artifacts from the storage layer
write_artifact_func: function used to write artifacts to the storage layer
infer_type_func: function used to infer the type of artifacts returned by operators.
"""
exec_state = ExecutionState(user_logs=Logs())
storage = parse_storage(spec.storage_config)
try:
@@ -247,7 +277,12 @@ def run(spec: FunctionSpec) -> None:
# Read the input data from intermediate storage.
inputs, _, serialization_types = time_it(
job_name=spec.name, job_type=spec.type.value, step="Reading Inputs"
)(utils.read_artifacts)(storage, spec.input_content_paths, spec.input_metadata_paths)
)(read_artifacts_func)(
storage=storage,
input_paths=spec.input_content_paths,
input_metadata_paths=spec.input_metadata_paths,
**kwargs,
)

# We need to check for BSON_TABLE serialization type at both the top level
# and within any serialized pickled collection (if it exists).
@@ -275,7 +310,9 @@ def run(spec: FunctionSpec) -> None:

print("Function invoked successfully!")

result_types = _validate_result_count_and_infer_type(spec, results)
result_types = _validate_result_count_and_infer_type(
spec=spec, results=results, infer_type_func=infer_type_func
)

# Perform type checking on the function output.
if spec.operator_type == OperatorType.METRIC:
@@ -316,15 +353,15 @@ def run(spec: FunctionSpec) -> None:
# not before recording the output artifact value (which will be False).
if not check_passed:
print(f"Check Operator did not pass.")

utils.write_artifact(
storage,
ArtifactType.BOOL,
derived_from_bson, # derived_from_bson doesn't apply to bool artifact
spec.output_content_paths[0],
spec.output_metadata_paths[0],
check_passed,
write_artifact_func(
storage=storage,
artifact_type=ArtifactType.BOOL,
derived_from_bson=derived_from_bson, # derived_from_bson doesn't apply to bool artifact
output_path=spec.output_content_paths[0],
output_metadata_path=spec.output_metadata_paths[0],
content=check_passed,
system_metadata=system_metadata,
**kwargs,
)

check_severity = spec.check_severity
@@ -359,13 +396,15 @@ def run(spec: FunctionSpec) -> None:
time_it(job_name=spec.name, job_type=spec.type.value, step="Writing Outputs")(
_write_artifacts
)(
results,
result_types,
derived_from_bson,
spec.output_content_paths,
spec.output_metadata_paths,
system_metadata,
storage,
write_artifact_func=write_artifact_func,
results=results,
result_types=result_types,
derived_from_bson=derived_from_bson,
output_content_paths=spec.output_content_paths,
output_metadata_paths=spec.output_metadata_paths,
system_metadata=system_metadata,
storage=storage,
**kwargs,
)

# If we made it here, then the operator has succeeded.