10 changes: 5 additions & 5 deletions python/pyspark/pipelines/api.py
@@ -23,7 +23,7 @@
from pyspark.pipelines.source_code_location import (
get_caller_source_code_location,
)
from pyspark.pipelines.dataset import (
from pyspark.pipelines.output import (
MaterializedView,
StreamingTable,
TemporaryView,
@@ -156,7 +156,7 @@ def outer(

resolved_name = name or decorated.__name__
registry = get_active_graph_element_registry()
registry.register_dataset(
registry.register_output(
StreamingTable(
comment=comment,
name=resolved_name,
@@ -258,7 +258,7 @@ def outer(

resolved_name = name or decorated.__name__
registry = get_active_graph_element_registry()
registry.register_dataset(
registry.register_output(
MaterializedView(
comment=comment,
name=resolved_name,
@@ -351,7 +351,7 @@ def outer(decorated: QueryFunction) -> None:

resolved_name = name or decorated.__name__
registry = get_active_graph_element_registry()
registry.register_dataset(
registry.register_output(
TemporaryView(
comment=comment,
name=resolved_name,
@@ -446,4 +446,4 @@ def create_streaming_table(
schema=schema,
format=format,
)
get_active_graph_element_registry().register_dataset(table)
get_active_graph_element_registry().register_output(table)
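
For orientation, the decorator hunks above all follow the same pattern: resolve the output name (falling back to the decorated function's name), build an output definition, and hand it to the active graph element registry through register_output(). The following is a minimal, self-contained sketch of that pattern using simplified stand-in classes, not the real pyspark.pipelines types:

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class Output:
    """Stand-in for the Output definition classes in this PR."""
    name: str
    comment: Optional[str] = None


class Registry:
    """Stand-in for a graph element registry."""
    def __init__(self) -> None:
        self.outputs: List[Output] = []

    def register_output(self, output: Output) -> None:
        self.outputs.append(output)


_active_registry = Registry()


def materialized_view(name: Optional[str] = None, comment: Optional[str] = None):
    def outer(decorated: Callable[[], object]) -> None:
        # Fall back to the function name, mirroring
        # `resolved_name = name or decorated.__name__` in the hunks above.
        resolved_name = name or decorated.__name__
        _active_registry.register_output(Output(name=resolved_name, comment=comment))

    # Like the decorators in api.py, outer() returns None, so the decorated
    # function is consumed by registration rather than kept as a callable.
    return outer


@materialized_view(comment="an example output")
def mv1():
    pass


print(_active_registry.outputs)  # [Output(name='mv1', comment='an example output')]
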
4 changes: 2 additions & 2 deletions python/pyspark/pipelines/graph_element_registry.py
@@ -18,7 +18,7 @@
from abc import ABC, abstractmethod
from pathlib import Path

from pyspark.pipelines.dataset import Dataset
from pyspark.pipelines.output import Output
from pyspark.pipelines.flow import Flow
from contextlib import contextmanager
from contextvars import ContextVar
@@ -35,7 +35,7 @@ class GraphElementRegistry(ABC):
"""

@abstractmethod
def register_dataset(self, dataset: Dataset) -> None:
def register_output(self, output: Output) -> None:
"""Add the given dataset to the registry."""

@abstractmethod
python/pyspark/pipelines/{dataset.py → output.py}
@@ -22,12 +22,12 @@


@dataclass(frozen=True)
class Dataset:
"""Base class for definitions of datasets in a pipeline dataflow graph.
class Output:
"""Base class for definitions of outputs in a pipeline dataflow graph.

:param name: The name of the dataset. May be a multi-part name, such as "db.table".
:param comment: Optional comment for the dataset.
:param source_code_location: The location of the source code that created this dataset.
:param name: The name of the output. May be a multi-part name, such as "db.table".
:param comment: Optional comment for the output.
:param source_code_location: The location of the source code that created this output.
This is used for debugging and tracing purposes.
"""

@@ -37,7 +37,7 @@ class Dataset:


@dataclass(frozen=True)
class Table(Dataset):
class Table(Output):
"""
Definition of a table in a pipeline dataflow graph, i.e. a catalog object backed by data in
physical storage.
@@ -69,7 +69,7 @@ class StreamingTable(Table):


@dataclass(frozen=True)
class TemporaryView(Dataset):
class TemporaryView(Output):
"""Definition of a temporary view in a pipeline dataflow graph. Temporary views can be
referenced by flows within the dataflow graph, but are not visible outside of the graph."""

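
Taken together, the hunks above rename the root of a small frozen-dataclass hierarchy: Output at the top, Table (with MaterializedView and StreamingTable) for catalog-backed outputs, and TemporaryView for graph-local views. Below is a rough sketch of how one of these definitions might be constructed; the field names mirror those referenced elsewhere in this PR, while the defaults and the SourceCodeLocation shape are illustrative assumptions rather than the real definitions:

from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence


@dataclass(frozen=True)
class SourceCodeLocation:
    # Illustrative shape only; the real class lives in
    # pyspark.pipelines.source_code_location.
    filename: str
    line_number: Optional[int] = None


@dataclass(frozen=True)
class Output:
    name: str
    comment: Optional[str]
    source_code_location: SourceCodeLocation


@dataclass(frozen=True)
class Table(Output):
    table_properties: Dict[str, str] = field(default_factory=dict)
    partition_cols: Optional[Sequence[str]] = None
    schema: Optional[str] = None  # the real field also accepts a StructType
    format: Optional[str] = None


@dataclass(frozen=True)
class MaterializedView(Table):
    pass


@dataclass(frozen=True)
class StreamingTable(Table):
    pass


@dataclass(frozen=True)
class TemporaryView(Output):
    pass


mv = MaterializedView(
    name="db.daily_sales",
    comment="aggregated sales",
    source_code_location=SourceCodeLocation(filename="definitions.py", line_number=12),
    format="parquet",
)
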
54 changes: 27 additions & 27 deletions python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -20,8 +20,8 @@
from pyspark.sql import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
from pyspark.pipelines.dataset import (
Dataset,
from pyspark.pipelines.output import (
Output,
MaterializedView,
Table,
StreamingTable,
@@ -37,65 +37,65 @@


class SparkConnectGraphElementRegistry(GraphElementRegistry):
"""Registers datasets and flows in a dataflow graph held in a Spark Connect server."""
"""Registers outputs and flows in a dataflow graph held in a Spark Connect server."""

def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
# Cast because mypy seems to think `spark`` is a function, not an object. Likely related to
# SPARK-47544.
self._client = cast(Any, spark).client
self._dataflow_graph_id = dataflow_graph_id

def register_dataset(self, dataset: Dataset) -> None:
if isinstance(dataset, Table):
if isinstance(dataset.schema, str):
schema_string = dataset.schema
def register_output(self, output: Output) -> None:
if isinstance(output, Table):
if isinstance(output.schema, str):
schema_string = output.schema
schema_data_type = None
elif isinstance(dataset.schema, StructType):
elif isinstance(output.schema, StructType):
schema_string = None
schema_data_type = pyspark_types_to_proto_types(dataset.schema)
schema_data_type = pyspark_types_to_proto_types(output.schema)
else:
schema_string = None
schema_data_type = None

table_details = pb2.PipelineCommand.DefineDataset.TableDetails(
table_properties=dataset.table_properties,
partition_cols=dataset.partition_cols,
format=dataset.format,
table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
table_properties=output.table_properties,
partition_cols=output.partition_cols,
format=output.format,
# Even though schema_string is not required, the generated Python code seems to
# erroneously think it is required.
schema_string=schema_string, # type: ignore[arg-type]
schema_data_type=schema_data_type,
)

if isinstance(dataset, MaterializedView):
dataset_type = pb2.DatasetType.MATERIALIZED_VIEW
elif isinstance(dataset, StreamingTable):
dataset_type = pb2.DatasetType.TABLE
if isinstance(output, MaterializedView):
output_type = pb2.OutputType.MATERIALIZED_VIEW
elif isinstance(output, StreamingTable):
output_type = pb2.OutputType.TABLE
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
messageParameters={"dataset_type": type(dataset).__name__},
messageParameters={"output_type": type(output).__name__},
)
elif isinstance(dataset, TemporaryView):
dataset_type = pb2.DatasetType.TEMPORARY_VIEW
elif isinstance(output, TemporaryView):
output_type = pb2.OutputType.TEMPORARY_VIEW
table_details = None
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_PIPELINES_DATASET_TYPE",
messageParameters={"dataset_type": type(dataset).__name__},
messageParameters={"output_type": type(output).__name__},
)

inner_command = pb2.PipelineCommand.DefineDataset(
inner_command = pb2.PipelineCommand.DefineOutput(
dataflow_graph_id=self._dataflow_graph_id,
dataset_name=dataset.name,
dataset_type=dataset_type,
comment=dataset.comment,
output_name=output.name,
output_type=output_type,
comment=output.comment,
table_details=table_details,
source_code_location=source_code_location_to_proto(dataset.source_code_location),
source_code_location=source_code_location_to_proto(output.source_code_location),
)

command = pb2.Command()
command.pipeline_command.define_dataset.CopyFrom(inner_command)
command.pipeline_command.define_output.CopyFrom(inner_command)
self._client.execute_command(command)

def register_flow(self, flow: Flow) -> None:
12 changes: 6 additions & 6 deletions python/pyspark/pipelines/tests/local_graph_element_registry.py
@@ -19,7 +19,7 @@
from pathlib import Path
from typing import List, Sequence

from pyspark.pipelines.dataset import Dataset
from pyspark.pipelines.output import Output
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry

@@ -32,12 +32,12 @@ class SqlFile:

class LocalGraphElementRegistry(GraphElementRegistry):
def __init__(self) -> None:
self._datasets: List[Dataset] = []
self._outputs: List[Output] = []
self._flows: List[Flow] = []
self._sql_files: List[SqlFile] = []

def register_dataset(self, dataset: Dataset) -> None:
self._datasets.append(dataset)
def register_output(self, output: Output) -> None:
self._outputs.append(output)

def register_flow(self, flow: Flow) -> None:
self._flows.append(flow)
@@ -46,8 +46,8 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
self._sql_files.append(SqlFile(sql_text, file_path))

@property
def datasets(self) -> Sequence[Dataset]:
return self._datasets
def outputs(self) -> Sequence[Output]:
return self._outputs

@property
def flows(self) -> Sequence[Flow]:
4 changes: 2 additions & 2 deletions python/pyspark/pipelines/tests/test_cli.py
@@ -295,8 +295,8 @@ def mv2():

registry = LocalGraphElementRegistry()
register_definitions(outer_dir / "pipeline.yaml", registry, spec)
self.assertEqual(len(registry.datasets), 1)
self.assertEqual(registry.datasets[0].name, "mv1")
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "mv1")

def test_register_definitions_file_raises_error(self):
"""Errors raised while executing definitions code should make it to the outer context."""
8 changes: 4 additions & 4 deletions python/pyspark/pipelines/tests/test_graph_element_registry.py
@@ -46,10 +46,10 @@ def flow1():
def flow2():
raise NotImplementedError()

self.assertEqual(len(registry.datasets), 3)
self.assertEqual(len(registry.outputs), 3)
self.assertEqual(len(registry.flows), 4)

mv_obj = registry.datasets[0]
mv_obj = registry.outputs[0]
self.assertEqual(mv_obj.name, "mv")
assert mv_obj.source_code_location.filename.endswith("test_graph_element_registry.py")

@@ -58,7 +58,7 @@ def flow2():
self.assertEqual(mv_flow_obj.target, "mv")
assert mv_flow_obj.source_code_location.filename.endswith("test_graph_element_registry.py")

st_obj = registry.datasets[1]
st_obj = registry.outputs[1]
self.assertEqual(st_obj.name, "st")
assert st_obj.source_code_location.filename.endswith("test_graph_element_registry.py")

@@ -67,7 +67,7 @@ def flow2():
self.assertEqual(st_flow_obj.target, "st")
assert mv_flow_obj.source_code_location.filename.endswith("test_graph_element_registry.py")

st2_obj = registry.datasets[2]
st2_obj = registry.outputs[2]
self.assertEqual(st2_obj.name, "st2")
assert st2_obj.source_code_location.filename.endswith("test_graph_element_registry.py")

4 changes: 2 additions & 2 deletions python/pyspark/pipelines/tests/test_init_cli.py
@@ -53,8 +53,8 @@ def test_init(self):
assert spec.name == project_name
registry = LocalGraphElementRegistry()
register_definitions(spec_path, registry, spec)
self.assertEqual(len(registry.datasets), 1)
self.assertEqual(registry.datasets[0].name, "example_python_materialized_view")
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "example_python_materialized_view")
self.assertEqual(len(registry.flows), 1)
self.assertEqual(registry.flows[0].name, "example_python_materialized_view")
self.assertEqual(registry.flows[0].target, "example_python_materialized_view")