7 changes: 7 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -174,6 +174,13 @@
"`<func_name>` does not allow a Column in a list."
]
},
"CONFLICTING_PIPELINE_REFRESH_OPTIONS" : {
"message" : [
"--full-refresh-all option conflicts with <conflicting_option>",
"The --full-refresh-all option performs a full refresh of all datasets, ",
"so specifying individual datasets with <conflicting_option> is not allowed."
]
},
"CONNECT_URL_ALREADY_DEFINED": {
"message": [
"Only one Spark Connect client URL can be set; however, got a different URL [<new_url>] from the existing [<existing_url>]."
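As a minimal sketch of how the new condition surfaces to callers (assuming a PySpark build that already includes the entry above), mirroring the check added to run() in cli.py below and using the same PySparkException accessors the new tests rely on:

```python
from pyspark.errors import PySparkException

# Minimal sketch: raise and inspect the new error condition. The errorClass and
# messageParameters mirror the validation added to run() in cli.py.
try:
    raise PySparkException(
        errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS",
        messageParameters={"conflicting_option": "--refresh"},
    )
except PySparkException as e:
    print(e.getCondition())          # CONFLICTING_PIPELINE_REFRESH_OPTIONS
    print(e.getMessageParameters())  # {'conflicting_option': '--refresh'}
```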
77 changes: 71 additions & 6 deletions python/pyspark/pipelines/cli.py
@@ -28,7 +28,7 @@
import yaml
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generator, Mapping, Optional, Sequence
from typing import Any, Generator, List, Mapping, Optional, Sequence

from pyspark.errors import PySparkException, PySparkTypeError
from pyspark.sql import SparkSession
@@ -217,8 +217,36 @@ def change_dir(path: Path) -> Generator[None, None, None]:
os.chdir(prev)


def run(spec_path: Path) -> None:
"""Run the pipeline defined with the given spec."""
def run(
Review comment (Contributor): High level question: did we consider putting refresh selection options in the pipeline spec, rather than as a CLI arg? More generally, what's the philosophy for whether a configuration should be accepted as a CLI arg vs. a pipeline spec field?

Reply (Contributor): If we expect it to vary across runs for the same pipeline, it should be a CLI arg. If we expect it to be static for a pipeline, it should live in the spec. I would expect selections to vary across runs.

spec_path: Path,
full_refresh: Sequence[str],
full_refresh_all: bool,
refresh: Sequence[str],
) -> None:
"""Run the pipeline defined with the given spec.

:param spec_path: Path to the pipeline specification file.
:param full_refresh: List of datasets to reset and recompute.
:param full_refresh_all: Perform a full graph reset and recompute.
:param refresh: List of datasets to update.
"""
# Validate conflicting arguments
if full_refresh_all:
if full_refresh:
raise PySparkException(
errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS",
messageParameters={
"conflicting_option": "--full_refresh",
},
)
if refresh:
raise PySparkException(
errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS",
messageParameters={
"conflicting_option": "--refresh",
},
)

log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...")
spec = load_pipeline_spec(spec_path)

@@ -242,20 +270,52 @@ def run(spec_path: Path) -> None:
register_definitions(spec_path, registry, spec)

log_with_curr_timestamp("Starting run...")
result_iter = start_run(spark, dataflow_graph_id)
result_iter = start_run(
spark,
dataflow_graph_id,
full_refresh=full_refresh,
full_refresh_all=full_refresh_all,
refresh=refresh,
)
try:
handle_pipeline_events(result_iter)
finally:
spark.stop()


def parse_table_list(value: str) -> List[str]:
"""Parse a comma-separated list of table names, handling whitespace."""
return [table.strip() for table in value.split(",") if table.strip()]


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Pipeline CLI")
subparsers = parser.add_subparsers(dest="command", required=True)

# "run" subcommand
run_parser = subparsers.add_parser("run", help="Run a pipeline.")
run_parser = subparsers.add_parser(
"run",
help="Run a pipeline. If no refresh options specified, "
"a default incremental update is performed.",
)
run_parser.add_argument("--spec", help="Path to the pipeline spec.")
run_parser.add_argument(
"--full-refresh",
type=parse_table_list,
action="extend",
help="List of datasets to reset and recompute (comma-separated).",
Review comment (Contributor): Here and below, should we document default behavior if this arg is not specified at all?

Review comment (Contributor): Will extend split using commas?

default=[],
)
run_parser.add_argument(
"--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute."
)
run_parser.add_argument(
"--refresh",
type=parse_table_list,
action="extend",
help="List of datasets to update (comma-separated).",
default=[],
)
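To the reviewer question above ("Will extend split using commas?"): argparse applies type= to each flag occurrence first, so parse_table_list turns "a, b" into ["a", "b"], and action="extend" then concatenates those lists across repeated flags. A small standalone sketch of that behavior (assuming Python 3.8+, where the "extend" action was added):

```python
import argparse
from typing import List


def parse_table_list(value: str) -> List[str]:
    """Split a comma-separated list of table names, dropping surrounding whitespace."""
    return [table.strip() for table in value.split(",") if table.strip()]


parser = argparse.ArgumentParser()
# Same pattern as --full-refresh / --refresh above: type= splits each value on
# commas, and action="extend" merges the resulting lists across repeated flags.
parser.add_argument("--refresh", type=parse_table_list, action="extend", default=[])

args = parser.parse_args(["--refresh", "a, b", "--refresh", "c"])
print(args.refresh)  # ['a', 'b', 'c']
```

This also speaks to the documentation question: when a flag is omitted entirely, the defaults stay as empty lists (and full_refresh_all stays False), which is the default incremental update described in the run help text.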

# "init" subcommand
init_parser = subparsers.add_parser(
@@ -283,6 +343,11 @@ def run(spec_path: Path) -> None:
else:
spec_path = find_pipeline_spec(Path.cwd())

run(spec_path=spec_path)
run(
spec_path=spec_path,
full_refresh=args.full_refresh,
full_refresh_all=args.full_refresh_all,
refresh=args.refresh,
)
elif args.command == "init":
init(args.name)
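A note on the design: the conflict check lives inside run() rather than in argparse. An alternative would be argparse's add_mutually_exclusive_group, sketched below as a hypothetical (not what this change does). A single group makes all of its members pairwise exclusive, so it could not allow --full-refresh and --refresh together while forbidding each of them with --full-refresh-all, which is one argument for the explicit check and the shared error-condition machinery.

```python
import argparse

# Hypothetical alternative sketch: let argparse reject the conflict at parse time.
parser = argparse.ArgumentParser(description="Pipeline CLI (sketch)")
group = parser.add_mutually_exclusive_group()
group.add_argument("--full-refresh-all", action="store_true")
group.add_argument("--full-refresh", default=[])
group.add_argument("--refresh", default=[])

# Works: only one member of the group is used.
print(parser.parse_args(["--full-refresh-all"]))

# Would exit with an argparse error, even though the PR intends to allow it:
#   parser.parse_args(["--full-refresh", "a", "--refresh", "b"])
```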
20 changes: 17 additions & 3 deletions python/pyspark/pipelines/spark_connect_pipeline.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from datetime import timezone
from typing import Any, Dict, Mapping, Iterator, Optional, cast
from typing import Any, Dict, Mapping, Iterator, Optional, cast, Sequence

import pyspark.sql.connect.proto as pb2
from pyspark.sql import SparkSession
@@ -65,12 +65,26 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
log_with_provided_timestamp(event.message, dt)


def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str, Any]]:
def start_run(
spark: SparkSession,
dataflow_graph_id: str,
full_refresh: Optional[Sequence[str]] = None,
full_refresh_all: bool = False,
refresh: Optional[Sequence[str]] = None,
) -> Iterator[Dict[str, Any]]:
"""Start a run of the dataflow graph in the Spark Connect server.

:param dataflow_graph_id: The ID of the dataflow graph to start.
:param full_refresh: List of datasets to reset and recompute.
:param full_refresh_all: Perform a full graph reset and recompute.
:param refresh: List of datasets to update.
"""
inner_command = pb2.PipelineCommand.StartRun(dataflow_graph_id=dataflow_graph_id)
inner_command = pb2.PipelineCommand.StartRun(
dataflow_graph_id=dataflow_graph_id,
full_refresh_selection=full_refresh or [],
full_refresh_all=full_refresh_all,
refresh_selection=refresh or [],
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
# Cast because mypy seems to think `spark` is a function, not an object. Likely related to
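For reference, a small sketch of the three refresh knobs the StartRun message now carries (field names as in the diff above; the graph ID is a placeholder and the snippet assumes a PySpark build that ships the pipelines Spark Connect protos). Note the `or []` coercion in start_run, so callers that pass no selections still send empty lists with full_refresh_all=False, i.e. the default incremental update:

```python
import pyspark.sql.connect.proto as pb2

# Placeholder graph ID, used purely for illustration.
msg = pb2.PipelineCommand.StartRun(
    dataflow_graph_id="graph-123",
    full_refresh_selection=["sales", "users"],  # reset and recompute just these
    full_refresh_all=False,                     # True would reset the whole graph
    refresh_selection=["events"],               # incrementally update these
)
print(msg)
```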
90 changes: 90 additions & 0 deletions python/pyspark/pipelines/tests/test_cli.py
@@ -36,6 +36,7 @@
unpack_pipeline_spec,
DefinitionsGlob,
PipelineSpec,
run,
)
from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry

@@ -358,6 +359,95 @@ def test_python_import_current_directory(self):
),
)

def test_full_refresh_all_conflicts_with_full_refresh(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
spec_path = Path(temp_dir) / "pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')

# Test that providing both --full-refresh-all and --full-refresh raises an exception
with self.assertRaises(PySparkException) as context:
run(
spec_path=spec_path,
full_refresh=["table1", "table2"],
full_refresh_all=True,
refresh=[],
)

self.assertEqual(
context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
)
self.assertEqual(
context.exception.getMessageParameters(), {"conflicting_option": "--full_refresh"}
)

def test_full_refresh_all_conflicts_with_refresh(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
spec_path = Path(temp_dir) / "pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')

# Test that providing both --full-refresh-all and --refresh raises an exception
with self.assertRaises(PySparkException) as context:
run(
spec_path=spec_path,
full_refresh=[],
full_refresh_all=True,
refresh=["table1", "table2"],
)

self.assertEqual(
context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
)
self.assertEqual(
context.exception.getMessageParameters(),
{"conflicting_option": "--refresh"},
)

def test_full_refresh_all_conflicts_with_both(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
spec_path = Path(temp_dir) / "pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')

# Test that providing --full-refresh-all with both other options raises an exception
# (it should catch the first conflict - full_refresh)
with self.assertRaises(PySparkException) as context:
run(
spec_path=spec_path,
full_refresh=["table1"],
full_refresh_all=True,
refresh=["table2"],
)

self.assertEqual(
context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS"
)

def test_parse_table_list_single_table(self):
"""Test parsing a single table name."""
from pyspark.pipelines.cli import parse_table_list

result = parse_table_list("table1")
self.assertEqual(result, ["table1"])

def test_parse_table_list_multiple_tables(self):
"""Test parsing multiple table names."""
from pyspark.pipelines.cli import parse_table_list

result = parse_table_list("table1,table2,table3")
self.assertEqual(result, ["table1", "table2", "table3"])

def test_parse_table_list_with_spaces(self):
"""Test parsing table names with spaces."""
from pyspark.pipelines.cli import parse_table_list

result = parse_table_list("table1, table2 , table3")
self.assertEqual(result, ["table1", "table2", "table3"])


if __name__ == "__main__":
try: