9 changes: 8 additions & 1 deletion docs/declarative-pipelines-programming-guide.md
@@ -94,7 +94,7 @@ The `spark-pipelines init` command, described below, makes it easy to generate a

## The `spark-pipelines` Command Line Interface

The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project.
The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project and a `dry-run` subcommand for validating a pipeline.

`spark-pipelines` is built on top of `spark-submit`, meaning that it supports all cluster managers supported by `spark-submit`. It supports all `spark-submit` arguments except for `--class`.

@@ -106,6 +106,13 @@ The `spark-pipelines` command line interface (CLI) is the primary way to execute

`spark-pipelines run` launches an execution of a pipeline and monitors its progress until it completes. The `--spec` parameter allows selecting the pipeline spec file. If not provided, the CLI will look in the current directory and parent directories for a file named `pipeline.yml` or `pipeline.yaml`.
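
A rough sketch of that lookup behavior, using a hypothetical helper named `locate_pipeline_spec` (the CLI's actual implementation differs and raises an error when no spec is found):

```python
from pathlib import Path
from typing import Optional


def locate_pipeline_spec(start: Path) -> Optional[Path]:
    """Return the first pipeline.yml or pipeline.yaml found walking up from `start`."""
    for directory in (start, *start.parents):
        for name in ("pipeline.yml", "pipeline.yaml"):
            candidate = directory / name
            if candidate.is_file():
                return candidate
    return None  # illustrative fallback; the real CLI errors out instead
```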

### `spark-pipelines dry-run`

`spark-pipelines dry-run` launches an execution of a pipeline that doesn't read or write any data, but catches many of the errors that would surface if the pipeline actually ran, for example (a short illustration follows the list):
- Syntax errors – e.g. invalid Python or SQL code
- Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
- Graph validation errors – e.g. cyclic dependencies
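
For instance, a pipeline definition like the following hypothetical one would fail a `dry-run` with an analysis error, because the table it reads from doesn't exist (the dataset and table names are made up):

```python
from pyspark import pipelines as sdp
from pyspark.sql import SparkSession

spark = SparkSession.active()


@sdp.materialized_view
def daily_totals():
    # `nonexistent_table` can't be resolved, so dry-run reports an analysis
    # error without reading or writing any data.
    return spark.read.table("nonexistent_table")
```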

## Programming with SDP in Python

SDP Python functions are defined in the `pyspark.pipelines` module. Your pipelines implemented with the Python API must import this module. It's common to alias the module to `sdp` to limit the number of characters you need to type when using its APIs.
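
For example, a minimal pipeline definition using that alias might look like the following (the dataset name `numbers` is purely illustrative):

```python
from pyspark import pipelines as sdp
from pyspark.sql import SparkSession

spark = SparkSession.active()


@sdp.materialized_view
def numbers():
    # A minimal materialized view backed by a small batch DataFrame.
    return spark.range(5)
```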
36 changes: 28 additions & 8 deletions python/pyspark/pipelines/cli.py
@@ -222,6 +222,7 @@ def run(
full_refresh: Sequence[str],
full_refresh_all: bool,
refresh: Sequence[str],
dry: bool,
) -> None:
"""Run the pipeline defined with the given spec.

@@ -276,6 +277,7 @@ def run(
full_refresh=full_refresh,
full_refresh_all=full_refresh_all,
refresh=refresh,
dry=dry,
)
try:
handle_pipeline_events(result_iter)
@@ -317,6 +319,13 @@ def parse_table_list(value: str) -> List[str]:
default=[],
)

# "dry-run" subcommand
dry_run_parser = subparsers.add_parser(
"dry-run",
help="Launch a run that just validates the graph and checks for errors.",
)
dry_run_parser.add_argument("--spec", help="Path to the pipeline spec.")

# "init" subcommand
init_parser = subparsers.add_parser(
"init",
@@ -330,9 +339,9 @@ def parse_table_list(value: str) -> List[str]:
)

args = parser.parse_args()
assert args.command in ["run", "init"]
assert args.command in ["run", "dry-run", "init"]

if args.command == "run":
if args.command in ["run", "dry-run"]:
if args.spec is not None:
spec_path = Path(args.spec)
if not spec_path.is_file():
@@ -343,11 +352,22 @@ def parse_table_list(value: str) -> List[str]:
else:
spec_path = find_pipeline_spec(Path.cwd())

run(
spec_path=spec_path,
full_refresh=args.full_refresh,
full_refresh_all=args.full_refresh_all,
refresh=args.refresh,
)
if args.command == "run":
run(
spec_path=spec_path,
full_refresh=args.full_refresh,
full_refresh_all=args.full_refresh_all,
refresh=args.refresh,
dry=args.command == "dry-run",
)
else:
assert args.command == "dry-run"
run(
spec_path=spec_path,
full_refresh=[],
full_refresh_all=False,
refresh=[],
dry=True,
)
elif args.command == "init":
init(args.name)
8 changes: 5 additions & 3 deletions python/pyspark/pipelines/spark_connect_pipeline.py
@@ -68,9 +68,10 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
def start_run(
spark: SparkSession,
dataflow_graph_id: str,
full_refresh: Optional[Sequence[str]] = None,
full_refresh_all: bool = False,
refresh: Optional[Sequence[str]] = None,
full_refresh: Optional[Sequence[str]],
full_refresh_all: bool,
refresh: Optional[Sequence[str]],
dry: bool,
) -> Iterator[Dict[str, Any]]:
"""Start a run of the dataflow graph in the Spark Connect server.

@@ -84,6 +85,7 @@ def start_run(
full_refresh_selection=full_refresh or [],
full_refresh_all=full_refresh_all,
refresh_selection=refresh or [],
dry=dry,
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
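
Putting the new `dry` parameter in context, the updated `start_run` API can be driven programmatically along these lines, mirroring the new test added below (the Connect URL is illustrative and graph-element registration is elided):

```python
from pyspark.sql import SparkSession
from pyspark.pipelines.spark_connect_pipeline import (
    create_dataflow_graph,
    handle_pipeline_events,
    start_run,
)

# Connect to a Spark Connect server; the URL here is just an example.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Assumes the pipeline's graph elements have already been registered
# against this dataflow graph.
graph_id = create_dataflow_graph(spark, None, None, None)

events = start_run(
    spark,
    graph_id,
    full_refresh=None,
    refresh=None,
    full_refresh_all=False,
    dry=True,
)
handle_pipeline_events(events)  # raises if validation uncovers errors
```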
3 changes: 3 additions & 0 deletions python/pyspark/pipelines/tests/test_cli.py
@@ -373,6 +373,7 @@ def test_full_refresh_all_conflicts_with_full_refresh(self):
full_refresh=["table1", "table2"],
full_refresh_all=True,
refresh=[],
dry=False,
)

self.assertEqual(
@@ -396,6 +397,7 @@ def test_full_refresh_all_conflicts_with_refresh(self):
full_refresh=[],
full_refresh_all=True,
refresh=["table1", "table2"],
dry=False,
)

self.assertEqual(
@@ -421,6 +423,7 @@ def test_full_refresh_all_conflicts_with_both(self):
full_refresh=["table1"],
full_refresh_all=True,
refresh=["table2"],
dry=False,
)

self.assertEqual(
97 changes: 97 additions & 0 deletions python/pyspark/pipelines/tests/test_spark_connect.py
@@ -0,0 +1,97 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Tests that run Pipelines against a Spark Connect server.
"""

import unittest

from pyspark.errors.exceptions.connect import AnalysisException
from pyspark.pipelines.graph_element_registry import graph_element_registration_context
from pyspark.pipelines.spark_connect_graph_element_registry import (
SparkConnectGraphElementRegistry,
)
from pyspark.pipelines.spark_connect_pipeline import (
create_dataflow_graph,
start_run,
handle_pipeline_events,
)
from pyspark import pipelines as sdp
from pyspark.testing.connectutils import (
ReusedConnectTestCase,
should_test_connect,
connect_requirement_message,
)


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectPipelinesTest(ReusedConnectTestCase):
def test_dry_run(self):
dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)

with graph_element_registration_context(registry):

@sdp.materialized_view
def mv():
return self.spark.range(1)

result_iter = start_run(
self.spark,
dataflow_graph_id,
full_refresh=None,
refresh=None,
full_refresh_all=False,
dry=True,
)
handle_pipeline_events(result_iter)

def test_dry_run_failure(self):
dataflow_graph_id = create_dataflow_graph(self.spark, None, None, None)
registry = SparkConnectGraphElementRegistry(self.spark, dataflow_graph_id)

with graph_element_registration_context(registry):

@sdp.table
def st():
# Invalid because a streaming query is expected
return self.spark.range(1)

result_iter = start_run(
self.spark,
dataflow_graph_id,
full_refresh=None,
refresh=None,
full_refresh_all=False,
dry=True,
)
with self.assertRaises(AnalysisException) as context:
handle_pipeline_events(result_iter)
self.assertIn(
"INVALID_FLOW_QUERY_TYPE.BATCH_RELATION_FOR_STREAMING_TABLE", str(context.exception)
)


if __name__ == "__main__":
try:
import xmlrunner # type: ignore

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)