Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ default boolean isInternalError() {
return SparkThrowableHelper.isInternalError(this.getCondition());
}

/**
 * Returns the breaking change info looked up for this error's condition.
 *
 * @return the breaking change info, or {@code null} if the error message is not for a
 *         breaking change
 */
default BreakingChangeInfo getBreakingChangeInfo() {
// The helper returns a Scala Option; unwrap it to a nullable value for Java callers.
return SparkThrowableHelper.getBreakingChangeInfo(
this.getCondition()).getOrElse(() -> null);
}

/**
 * Default implementation: returns a new, empty {@link HashMap} on every call.
 * Presumably overridden by implementations that carry message parameters.
 */
default Map<String, String> getMessageParameters() {
return new HashMap<>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
matches.map(m => m.stripSuffix(">").stripPrefix("<"))
}

/**
 * Looks up the breaking change metadata for an error class.
 *
 * The error class is either a main class ("MAIN") or a main class with a subclass
 * ("MAIN.SUB"). For the two-part form, the subclass's own metadata is preferred and
 * the main class's metadata is used as a fallback. Any other shape yields `None`.
 *
 * @param errorClass The (possibly dotted) error class name.
 * @return The breaking change info, if any is registered for the error class.
 */
def getBreakingChangeInfo(errorClass: String): Option[BreakingChangeInfo] = {
  errorClass.split('.') match {
    case Array(mainClass) =>
      errorInfoMap.get(mainClass).flatMap(_.breakingChangeInfo)
    case Array(mainClass, subClass) =>
      errorInfoMap.get(mainClass).flatMap { mainInfo =>
        // Subclass-level metadata wins; fall back to the main class when it is absent.
        val fromSubClass = for {
          subClasses <- mainInfo.subClass
          subInfo <- subClasses.get(subClass)
          info <- subInfo.breakingChangeInfo
        } yield info
        fromSubClass.orElse(mainInfo.breakingChangeInfo)
      }
    case _ => None
  }
}

def getMessageTemplate(errorClass: String): String = {
val errorClasses = errorClass.split("\\.")
assert(errorClasses.length == 1 || errorClasses.length == 2)
Expand Down Expand Up @@ -128,7 +144,7 @@ private object ErrorClassesJsonReader {
val map = mapper.readValue(url, new TypeReference[Map[String, ErrorInfo]]() {})
val errorClassWithDots = map.collectFirst {
case (errorClass, _) if errorClass.contains('.') => errorClass
case (_, ErrorInfo(_, Some(map), _)) if map.keys.exists(_.contains('.')) =>
case (_, ErrorInfo(_, Some(map), _, _)) if map.keys.exists(_.contains('.')) =>
map.keys.collectFirst { case s if s.contains('.') => s }.get
}
if (errorClassWithDots.isEmpty) {
Expand All @@ -147,28 +163,59 @@ private object ErrorClassesJsonReader {
* @param subClass SubClass associated with this class.
* @param message Message format with optional placeholders (e.g. &lt;parm&gt;).
* The error message is constructed by concatenating the lines with newlines.
* @param breakingChangeInfo Additional metadata if the error is due to a breaking change.
*/
private case class ErrorInfo(
message: Seq[String],
subClass: Option[Map[String, ErrorSubInfo]],
sqlState: Option[String]) {
sqlState: Option[String],
breakingChangeInfo: Option[BreakingChangeInfo] = None) {
// For compatibility with multi-line error messages
@JsonIgnore
val messageTemplate: String = message.mkString("\n")
val messageTemplate: String = message.mkString("\n") +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we add an \n between the main error message and the breaking change message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to use a space

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this may look weird, as the migration message itself can span multiple lines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I based this code on the existing logic for joining the subclass message:

errorInfo.messageTemplate + " " + errorSubInfo.messageTemplate

That logic uses a space so I think it makes sense to match that for consistency.

In the common case where the message is a single line, I think a newline is more confusing than a space.

breakingChangeInfo.map(_.migrationMessage.mkString(" ", "\n", "")).getOrElse("")
}

/**
* Information associated with an error subclass.
*
* @param message Message format with optional placeholders (e.g. &lt;parm&gt;).
* The error message is constructed by concatenating the lines with newlines.
* @param breakingChangeInfo Additional metadata if the error is due to a breaking change.
*/
private case class ErrorSubInfo(message: Seq[String]) {
private case class ErrorSubInfo(
message: Seq[String],
breakingChangeInfo: Option[BreakingChangeInfo] = None) {
// For compatibility with multi-line error messages
@JsonIgnore
val messageTemplate: String = message.mkString("\n")
val messageTemplate: String = message.mkString("\n") +
breakingChangeInfo.map(_.migrationMessage.mkString(" ", "\n", "")).getOrElse("")
}

/**
* Additional information if the error was caused by a breaking change.
*
* @param migrationMessage A message explaining how the user can migrate their job to work
* with the breaking change.
* @param mitigationConfig A spark config flag that can be used to mitigate the
* breaking change.
* @param needsAudit If true, the breaking change should be inspected manually.
* If false, the spark job should be retried by setting the
* mitigationConfig.
*/
case class BreakingChangeInfo(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that we expose it as a public API via SparkThrowable.getBreakingChangeInfo. We shouldn't expose a case class as a public API, as it has a wide API surface, including the companion object.

We should follow SparkThrowable and define it in Java.

migrationMessage: Seq[String],
mitigationConfig: Option[MitigationConfig] = None,
needsAudit: Boolean = true
)

/**
* A spark config flag that can be used to mitigate a breaking change.
* @param key The spark config key.
* @param value The spark config value that mitigates the breaking change.
*/
case class MitigationConfig(key: String, value: String)

/**
* Information associated with an error state / SQLSTATE.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ private[spark] object SparkThrowableHelper {
errorReader.getMessageParameters(errorClass)
}

/**
 * Returns the breaking change info for the given error class.
 *
 * @param errorClass The error class name; may be null.
 * @return `None` when `errorClass` is null, otherwise whatever the error reader
 *         returns for that class.
 */
def getBreakingChangeInfo(errorClass: String): Option[BreakingChangeInfo] =
  // Option(null) is None, so a null error class short-circuits without touching the reader.
  Option(errorClass).flatMap(errorReader.getBreakingChangeInfo)

/**
 * Tells whether the error class denotes an internal error.
 *
 * @param errorClass The error class name; may be null.
 * @return true iff `errorClass` is non-null and starts with "INTERNAL_ERROR".
 */
def isInternalError(errorClass: String): Boolean =
  Option(errorClass).exists(_.startsWith("INTERNAL_ERROR"))
Expand All @@ -99,6 +107,19 @@ private[spark] object SparkThrowableHelper {
g.writeStringField("errorClass", errorClass)
if (format == STANDARD) {
g.writeStringField("messageTemplate", errorReader.getMessageTemplate(errorClass))
errorReader.getBreakingChangeInfo(errorClass).foreach { breakingChangeInfo =>
g.writeObjectFieldStart("breakingChangeInfo")
g.writeStringField("migrationMessage",
breakingChangeInfo.migrationMessage.mkString("\n"))
breakingChangeInfo.mitigationConfig.foreach { mitigationConfig =>
g.writeObjectFieldStart("mitigationConfig")
g.writeStringField("key", mitigationConfig.key)
g.writeStringField("value", mitigationConfig.value)
g.writeEndObject()
}
g.writeBooleanField("needsAudit", breakingChangeInfo.needsAudit)
g.writeEndObject()
}
}
val sqlState = e.getSqlState
if (sqlState != null) g.writeStringField("sqlState", sqlState)
Expand Down
84 changes: 84 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,90 @@ class SparkThrowableSuite extends SparkFunSuite {
}
}

// Verifies breaking change metadata end to end: the null-safe helper lookup, the
// nullable accessor on exceptions, and parsing/formatting of `breakingChangeInfo`
// entries read from a JSON error file.
test("breaking changes info") {
// A null error class must yield no breaking change info rather than throwing.
assert(SparkThrowableHelper.getBreakingChangeInfo(null).isEmpty)

// This condition is expected to carry no breaking change metadata, so the
// exception accessor should return null.
val nonBreakingChangeError = new SparkException(
errorClass = "CANNOT_PARSE_DECIMAL",
messageParameters = Map.empty[String, String],
cause = null)
assert(nonBreakingChangeError.getBreakingChangeInfo == null)

withTempDir { dir =>
// Fixture with two conditions: TEST_ERROR declares breakingChangeInfo on the main
// class, TEST_ERROR_WITH_SUBCLASS declares it only on its SUBCLASS entry.
val json = new File(dir, "errors.json")
Files.writeString(
json.toPath,
"""
|{
| "TEST_ERROR": {
| "message": [
| "Error message 1 with <param1>."
| ],
| "breakingChangeInfo": {
| "migrationMessage": [
| "Migration message with <param2>."
| ],
| "mitigationConfig": {
| "key": "config.key1",
| "value": "config.value1"
| },
| "needsAudit": false
| }
| },
| "TEST_ERROR_WITH_SUBCLASS": {
| "message": [
| "Error message 2 with <param1>."
| ],
| "subClass": {
| "SUBCLASS": {
| "message": [
| "Subclass message with <param2>."
| ],
| "breakingChangeInfo": {
| "migrationMessage": [
| "Subclass migration message with <param3>."
| ],
| "mitigationConfig": {
| "key": "config.key2",
| "value": "config.value2"
| },
| "needsAudit": true
| }
| }
| }
| }
|}
|""".stripMargin,
StandardCharsets.UTF_8)

val error1Params = Map("param1" -> "value1", "param2" -> "value2")
val error2Params = Map("param1" -> "value1", "param2" -> "value2", "param3" -> "value3")

// The reader is given both the file at errorJsonFilePath and the fixture above.
val reader =
new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL))
// The migration message is appended to the error message after a single space, and
// placeholders in both parts are substituted.
val errorMessage = reader.getErrorMessage("TEST_ERROR", error1Params)
assert(errorMessage == "Error message 1 with value1. Migration message with value2.")
// The structured info keeps the raw template, i.e. placeholders are not substituted.
val breakingChangeInfo = reader.getBreakingChangeInfo("TEST_ERROR")
assert(
breakingChangeInfo.contains(
BreakingChangeInfo(
Seq("Migration message with <param2>."),
Some(MitigationConfig("config.key1", "config.value1")),
needsAudit = false)))
// For a subclass, the rendered message is main message, subclass message and
// subclass migration message, each joined by a space.
val errorMessage2 =
reader.getErrorMessage("TEST_ERROR_WITH_SUBCLASS.SUBCLASS", error2Params)
assert(
errorMessage2 == "Error message 2 with value1. Subclass message with value2." +
" Subclass migration message with value3.")
// needsAudit is omitted from the expected value: the case class default (true)
// matches the fixture's "needsAudit": true.
val breakingChangeInfo2 = reader.getBreakingChangeInfo("TEST_ERROR_WITH_SUBCLASS.SUBCLASS")
assert(
breakingChangeInfo2.contains(
BreakingChangeInfo(
Seq("Subclass migration message with <param3>."),
Some(MitigationConfig("config.key2", "config.value2")))))
}
}

test("detect unused message parameters") {
checkError(
exception = intercept[SparkException] {
Expand Down
19 changes: 18 additions & 1 deletion python/pyspark/errors/exceptions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, TypeVar, cast, Iterable, TYPE_CHECKING, List
from typing import Any, Dict, Optional, TypeVar, cast, Iterable, TYPE_CHECKING, List

from pyspark.errors.exceptions.tblib import Traceback
from pyspark.errors.utils import ErrorClassesReader
Expand Down Expand Up @@ -138,6 +138,23 @@ def getMessage(self) -> str:
"""
return f"[{self.getCondition()}] {self._message}"

def getBreakingChangeInfo(self) -> Optional[Dict[str, Any]]:
"""
Returns the breaking change info for an error, or None.

Breaking change info is a dict with the following fields:

migration_message: list of str
A message explaining how the user can migrate their job to work
with the breaking change.

mitigation_config:
A dict with key: str and value: str fields.
A spark config flag that can be used to mitigate the
breaking change.
"""
# NOTE(review): the dict is produced by the error reader for this error class; the
# exact keys it returns (e.g. whether a needs_audit entry is present) are not visible
# here -- confirm against ErrorClassesReader.get_breaking_change_info.
return self._error_reader.get_breaking_change_info(self._errorClass)

def getQueryContext(self) -> List["QueryContext"]:
"""
Returns :class:`QueryContext`.
Expand Down
32 changes: 31 additions & 1 deletion python/pyspark/errors/exceptions/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import grpc
import json
from grpc import StatusCode
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
Expand Down Expand Up @@ -95,6 +95,7 @@ def _convert_exception(
display_server_stacktrace = display_server_stacktrace if stacktrace else False

contexts = None
breaking_change_info = None
if resp and resp.HasField("root_error_idx"):
root_error = resp.errors[resp.root_error_idx]
if hasattr(root_error, "spark_throwable"):
Expand All @@ -105,6 +106,20 @@ def _convert_exception(
else DataFrameQueryContext(c)
for c in root_error.spark_throwable.query_contexts
]
# Extract breaking change info if present
if hasattr(
root_error.spark_throwable, "breaking_change_info"
) and root_error.spark_throwable.HasField("breaking_change_info"):
bci = root_error.spark_throwable.breaking_change_info
breaking_change_info = {
"migration_message": list(bci.migration_message),
"needs_audit": bci.needs_audit if bci.HasField("needs_audit") else True,
}
if bci.HasField("mitigation_config"):
breaking_change_info["mitigation_config"] = {
"key": bci.mitigation_config.key,
"value": bci.mitigation_config.value,
}

if "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
Expand Down Expand Up @@ -134,6 +149,7 @@ def _convert_exception(
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
grpc_status_code=grpc_status_code,
breaking_change_info=breaking_change_info,
)

# Return UnknownException if there is no matched exception class
Expand All @@ -147,6 +163,7 @@ def _convert_exception(
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
grpc_status_code=grpc_status_code,
breaking_change_info=breaking_change_info,
)


Expand Down Expand Up @@ -193,6 +210,7 @@ def __init__(
display_server_stacktrace: bool = False,
contexts: Optional[List[BaseQueryContext]] = None,
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
breaking_change_info: Optional[Dict[str, Any]] = None,
) -> None:
if contexts is None:
contexts = []
Expand Down Expand Up @@ -221,6 +239,7 @@ def __init__(
self._display_stacktrace: bool = display_server_stacktrace
self._contexts: List[BaseQueryContext] = contexts
self._grpc_status_code = grpc_status_code
self._breaking_change_info: Optional[Dict[str, Any]] = breaking_change_info
self._log_exception()

def getSqlState(self) -> Optional[str]:
Expand All @@ -241,6 +260,15 @@ def getMessage(self) -> str:
def getGrpcStatusCode(self) -> grpc.StatusCode:
return self._grpc_status_code

def getBreakingChangeInfo(self) -> Optional[Dict[str, Any]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add a BreakingChangeInfo class in python as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a proto class defined in sql/connect/common/src/main/protobuf/spark/connect/base.proto but I didn't want to introduce that as a dependency here

"""
Returns the breaking change info for an error, or None.

For Spark Connect exceptions, this returns the breaking change info
received from the server, rather than looking it up from local error files.
"""
return self._breaking_change_info

def __str__(self) -> str:
return self.getMessage()

Expand All @@ -263,6 +291,7 @@ def __init__(
display_server_stacktrace: bool = False,
contexts: Optional[List[BaseQueryContext]] = None,
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
breaking_change_info: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
message=message,
Expand All @@ -274,6 +303,7 @@ def __init__(
display_server_stacktrace=display_server_stacktrace,
contexts=contexts,
grpc_status_code=grpc_status_code,
breaking_change_info=breaking_change_info,
)


Expand Down
Loading