28 changes: 28 additions & 0 deletions python/pyspark/pipelines/source_code_location.py
@@ -30,6 +30,34 @@ def get_caller_source_code_location(stacklevel: int) -> SourceCodeLocation:
"""
    Returns a SourceCodeLocation object representing the location of the code that invokes this
    function.

    If this function is called from a decorator (e.g. @sdp.table), note that the returned line
    number depends on how the decorator was invoked - i.e. whether @sdp.table or @sdp.table()
    was used - and on which Python version is running.

Case 1:
|@sdp.table()
|def fn

    @sdp.table() is executed immediately, on line 1. This is true for all Python versions.

Case 2:
|@sdp.table
|def fn

    In Python < 3.10, @sdp.table expands to fn = sdp.table(fn), and the reported line is the line
    that `fn` is defined on - line 2 here. More interestingly, this means:

|@sdp.table
|
|
|def fn

    The above expands to fn = sdp.table(fn) on line 4, where `fn` is defined.

    However, in Python 3.10+, the line number in the stack trace is the line that the decorator
    itself appears on. In other words, case 2 is treated the same as case 1, and the line number
    will be 1.

:param stacklevel: The number of stack frames to go up. 0 means the direct caller of this
function, 1 means the caller of the caller, and so on.
"""
11 changes: 11 additions & 0 deletions python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -29,6 +29,7 @@
)
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry
from pyspark.pipelines.source_code_location import SourceCodeLocation
from typing import Any, cast
import pyspark.sql.connect.proto as pb2

@@ -79,6 +80,7 @@ def register_dataset(self, dataset: Dataset) -> None:
partition_cols=partition_cols,
schema=schema,
format=format,
source_code_location=source_code_location_to_proto(dataset.source_code_location),
)
command = pb2.Command()
command.pipeline_command.define_dataset.CopyFrom(inner_command)
@@ -95,6 +97,7 @@ def register_flow(self, flow: Flow) -> None:
target_dataset_name=flow.target,
relation=relation,
sql_conf=flow.spark_conf,
source_code_location=source_code_location_to_proto(flow.source_code_location),
)
command = pb2.Command()
command.pipeline_command.define_flow.CopyFrom(inner_command)
@@ -109,3 +112,11 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
command = pb2.Command()
command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
self._client.execute_command(command)


def source_code_location_to_proto(
source_code_location: SourceCodeLocation,
) -> pb2.SourceCodeLocation:
return pb2.SourceCodeLocation(
file_name=source_code_location.filename, line_number=source_code_location.line_number
)
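
For illustration, converting a captured location into its Connect proto might look like this (the filename and line number are made-up values):

loc = SourceCodeLocation(filename="my_pipeline.py", line_number=17)
proto_loc = source_code_location_to_proto(loc)
assert proto_loc.file_name == "my_pipeline.py"
assert proto_loc.line_number == 17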
68 changes: 35 additions & 33 deletions python/pyspark/sql/connect/proto/pipelines_pb2.py

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -233,6 +233,7 @@ class PipelineCommand(google.protobuf.message.Message):
PARTITION_COLS_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
FORMAT_FIELD_NUMBER: builtins.int
SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to attach this dataset to."""
dataset_name: builtins.str
Expand Down Expand Up @@ -260,6 +261,9 @@ class PipelineCommand(google.protobuf.message.Message):
"""The output table format of the dataset. Only applies to dataset_type == TABLE and
dataset_type == MATERIALIZED_VIEW.
"""
@property
def source_code_location(self) -> global___SourceCodeLocation:
"""The location in source code that this dataset was defined."""
def __init__(
self,
*,
@@ -271,6 +275,7 @@
partition_cols: collections.abc.Iterable[builtins.str] | None = ...,
schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
format: builtins.str | None = ...,
source_code_location: global___SourceCodeLocation | None = ...,
) -> None: ...
def HasField(
self,
@@ -287,6 +292,8 @@
b"_format",
"_schema",
b"_schema",
"_source_code_location",
b"_source_code_location",
"comment",
b"comment",
"dataflow_graph_id",
@@ -299,6 +306,8 @@
b"format",
"schema",
b"schema",
"source_code_location",
b"source_code_location",
],
) -> builtins.bool: ...
def ClearField(
@@ -316,6 +325,8 @@
b"_format",
"_schema",
b"_schema",
"_source_code_location",
b"_source_code_location",
"comment",
b"comment",
"dataflow_graph_id",
Expand All @@ -330,6 +341,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"partition_cols",
"schema",
b"schema",
"source_code_location",
b"source_code_location",
"table_properties",
b"table_properties",
],
@@ -359,6 +372,13 @@
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
) -> typing_extensions.Literal["schema"] | None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_source_code_location", b"_source_code_location"
],
) -> typing_extensions.Literal["source_code_location"] | None: ...

class DefineFlow(google.protobuf.message.Message):
"""Request to define a flow targeting a dataset."""
@@ -415,6 +435,7 @@
RELATION_FIELD_NUMBER: builtins.int
SQL_CONF_FIELD_NUMBER: builtins.int
CLIENT_ID_FIELD_NUMBER: builtins.int
SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to attach this flow to."""
flow_name: builtins.str
@@ -435,6 +456,9 @@
"""Identifier for the client making the request. The server uses this to determine what flow
evaluation request stream to dispatch evaluation requests to for this flow.
"""
@property
def source_code_location(self) -> global___SourceCodeLocation:
"""The location in source code that this flow was defined."""
def __init__(
self,
*,
@@ -444,6 +468,7 @@
relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
sql_conf: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
client_id: builtins.str | None = ...,
source_code_location: global___SourceCodeLocation | None = ...,
) -> None: ...
def HasField(
self,
@@ -456,6 +481,8 @@
b"_flow_name",
"_relation",
b"_relation",
"_source_code_location",
b"_source_code_location",
"_target_dataset_name",
b"_target_dataset_name",
"client_id",
@@ -466,6 +493,8 @@
b"flow_name",
"relation",
b"relation",
"source_code_location",
b"source_code_location",
"target_dataset_name",
b"target_dataset_name",
],
@@ -481,6 +510,8 @@
b"_flow_name",
"_relation",
b"_relation",
"_source_code_location",
b"_source_code_location",
"_target_dataset_name",
b"_target_dataset_name",
"client_id",
@@ -491,6 +522,8 @@
b"flow_name",
"relation",
b"relation",
"source_code_location",
b"source_code_location",
"sql_conf",
b"sql_conf",
"target_dataset_name",
@@ -515,6 +548,13 @@
self, oneof_group: typing_extensions.Literal["_relation", b"_relation"]
) -> typing_extensions.Literal["relation"] | None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_source_code_location", b"_source_code_location"
],
) -> typing_extensions.Literal["source_code_location"] | None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_target_dataset_name", b"_target_dataset_name"],
@@ -1134,6 +1174,60 @@

global___PipelineEvent = PipelineEvent

class SourceCodeLocation(google.protobuf.message.Message):
"""Source code location information associated with a particular dataset or flow."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

FILE_NAME_FIELD_NUMBER: builtins.int
LINE_NUMBER_FIELD_NUMBER: builtins.int
file_name: builtins.str
"""The file that this pipeline source code was defined in."""
line_number: builtins.int
"""The specific line number that this pipeline source code is located at, if applicable."""
def __init__(
self,
*,
file_name: builtins.str | None = ...,
line_number: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_file_name",
b"_file_name",
"_line_number",
b"_line_number",
"file_name",
b"file_name",
"line_number",
b"line_number",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_file_name",
b"_file_name",
"_line_number",
b"_line_number",
"file_name",
b"file_name",
"line_number",
b"line_number",
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_file_name", b"_file_name"]
) -> typing_extensions.Literal["file_name"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_line_number", b"_line_number"]
) -> typing_extensions.Literal["line_number"] | None: ...

global___SourceCodeLocation = SourceCodeLocation

class PipelineQueryFunctionExecutionSignal(google.protobuf.message.Message):
"""A signal from the server to the client to execute the query function for one or more flows, and
to register their results with the server.
14 changes: 14 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -86,6 +86,9 @@ message PipelineCommand {
// The output table format of the dataset. Only applies to dataset_type == TABLE and
// dataset_type == MATERIALIZED_VIEW.
optional string format = 8;

// The location in source code that this dataset was defined.
optional SourceCodeLocation source_code_location = 9;
}

// Request to define a flow targeting a dataset.
@@ -110,6 +113,9 @@ message PipelineCommand {
// evaluation request stream to dispatch evaluation requests to for this flow.
optional string client_id = 6;

// The location in source code that this flow was defined.
optional SourceCodeLocation source_code_location = 7;

message Response {
// Fully qualified flow name that uniquely identify a flow in the Dataflow graph.
optional string flow_name = 1;
@@ -217,6 +223,14 @@ message PipelineEvent {
optional string message = 2;
}

// Source code location information associated with a particular dataset or flow.
message SourceCodeLocation {
// The file that this pipeline source code was defined in.
optional string file_name = 1;
// The specific line number that this pipeline source code is located at, if applicable.
optional int32 line_number = 2;
}
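
Because both fields are declared optional, the generated Python class tracks field presence explicitly, so an unset line_number is distinguishable from 0. A quick sketch with made-up values:

import pyspark.sql.connect.proto as pb2

loc = pb2.SourceCodeLocation(file_name="pipeline.py")
assert loc.HasField("file_name")
assert not loc.HasField("line_number")  # unset, not merely 0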

// A signal from the server to the client to execute the query function for one or more flows, and
// to register their results with the server.
message PipelineQueryFunctionExecutionSignal {
@@ -195,7 +195,11 @@ private[connect] object PipelinesHandler extends Logging {
partitionCols = Option(dataset.getPartitionColsList.asScala.toSeq)
.filter(_.nonEmpty),
properties = dataset.getTablePropertiesMap.asScala.toMap,
baseOrigin = QueryOrigin(
origin = QueryOrigin(
filePath = Option.when(dataset.getSourceCodeLocation.hasFileName)(
[Review comment (Member) on the line above: nit: we can store filePath and line in variables or a method to avoid duplicated code]
dataset.getSourceCodeLocation.getFileName),
line = Option.when(dataset.getSourceCodeLocation.hasLineNumber)(
dataset.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.Table.toString),
objectName = Option(qualifiedIdentifier.unquotedString),
language = Option(Python())),
Expand All @@ -212,6 +216,10 @@ private[connect] object PipelinesHandler extends Logging {
identifier = viewIdentifier,
comment = Option(dataset.getComment),
origin = QueryOrigin(
filePath = Option.when(dataset.getSourceCodeLocation.hasFileName)(
dataset.getSourceCodeLocation.getFileName),
line = Option.when(dataset.getSourceCodeLocation.hasLineNumber)(
dataset.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.View.toString),
objectName = Option(viewIdentifier.unquotedString),
language = Option(Python())),
Expand Down Expand Up @@ -281,6 +289,10 @@ private[connect] object PipelinesHandler extends Logging {
once = false,
queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)),
origin = QueryOrigin(
filePath = Option.when(flow.getSourceCodeLocation.hasFileName)(
flow.getSourceCodeLocation.getFileName),
line = Option.when(flow.getSourceCodeLocation.hasLineNumber)(
flow.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.Flow.toString),
objectName = Option(flowIdentifier.unquotedString),
language = Option(Python()))))