diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5d8b714711774..87bfbdf64a49f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -31,9 +31,10 @@ class Module(object):
     files have changed.
     """
 
-    def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={},
-                 sbt_test_goals=(), python_test_goals=(), excluded_python_implementations=(),
-                 test_tags=(), should_run_r_tests=False, should_run_build_tests=False):
+    def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
+                 environ=None, sbt_test_goals=(), python_test_goals=(),
+                 excluded_python_implementations=(), test_tags=(), should_run_r_tests=False,
+                 should_run_build_tests=False):
         """
         Define a new module.
 
@@ -62,7 +63,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
         self.source_file_prefixes = source_file_regexes
         self.sbt_test_goals = sbt_test_goals
         self.build_profile_flags = build_profile_flags
-        self.environ = environ
+        self.environ = environ or {}
         self.python_test_goals = python_test_goals
         self.excluded_python_implementations = excluded_python_implementations
         self.test_tags = test_tags
diff --git a/dev/tox.ini b/dev/tox.ini
index 7edf7d597fb58..43cd5877dfdb8 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -19,6 +19,6 @@ max-line-length=100
 exclude=python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*
 
 [flake8]
-select = E901,E999,F821,F822,F823,F401,F405
+select = E901,E999,F821,F822,F823,F401,F405,B006
 exclude = python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi
 max-line-length = 100
diff --git a/python/mypy.ini b/python/mypy.ini
index 5103452a053be..ad4fcf7f317f0 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -102,6 +102,8 @@ disallow_untyped_defs = False
 
 ; Ignore errors in embedded third party code
 
+no_implicit_optional = True
+
 [mypy-pyspark.cloudpickle.*]
 ignore_errors = True
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index d37654a7388f5..8ecb68458ffbc 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1801,7 +1801,7 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
     @keyword_only
     def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                 quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
+                 quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),  # noqa: B005
                  quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):
         """
         __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
@@ -1819,7 +1819,7 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p
     @since("1.6.0")
     def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
+                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),  # noqa: B005
                   quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0):
         """
         setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 2c083182de470..2bddfe822f29e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -835,13 +835,13 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
     .. versionadded:: 1.4.0
     """
 
-    def __init__(self, bestModel, avgMetrics=[], subModels=None):
+    def __init__(self, bestModel, avgMetrics=None, subModels=None):
         super(CrossValidatorModel, self).__init__()
         #: best model from cross validation
         self.bestModel = bestModel
         #: Average cross-validation metrics for each paramMap in
         #: CrossValidator.estimatorParamMaps, in the corresponding order.
-        self.avgMetrics = avgMetrics
+        self.avgMetrics = avgMetrics or []
         #: sub model list from cross validation
         self.subModels = subModels
@@ -1323,12 +1323,12 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable,
     .. versionadded:: 2.0.0
     """
 
-    def __init__(self, bestModel, validationMetrics=[], subModels=None):
+    def __init__(self, bestModel, validationMetrics=None, subModels=None):
         super(TrainValidationSplitModel, self).__init__()
         #: best model from train validation split
         self.bestModel = bestModel
         #: evaluated validation metrics
-        self.validationMetrics = validationMetrics
+        self.validationMetrics = validationMetrics or []
         #: sub models from train validation split
         self.subModels = subModels
diff --git a/python/pyspark/ml/tuning.pyi b/python/pyspark/ml/tuning.pyi
index e5f153d49e9c6..912abd4d7124a 100644
--- a/python/pyspark/ml/tuning.pyi
+++ b/python/pyspark/ml/tuning.pyi
@@ -104,7 +104,7 @@ class CrossValidatorModel(
     def __init__(
         self,
         bestModel: Model,
-        avgMetrics: List[float] = ...,
+        avgMetrics: Optional[List[float]] = ...,
         subModels: Optional[List[List[Model]]] = ...,
     ) -> None: ...
     def copy(self, extra: Optional[ParamMap] = ...) -> CrossValidatorModel: ...
@@ -171,7 +171,7 @@ class TrainValidationSplitModel(
     def __init__(
         self,
         bestModel: Model,
-        validationMetrics: List[float] = ...,
+        validationMetrics: Optional[List[float]] = ...,
         subModels: Optional[List[Model]] = ...,
     ) -> None: ...
     def setEstimator(self, value: Estimator) -> TrainValidationSplitModel: ...
diff --git a/python/pyspark/resource/profile.py b/python/pyspark/resource/profile.py
index 1c59a1c4a123c..38a68bc74d97e 100644
--- a/python/pyspark/resource/profile.py
+++ b/python/pyspark/resource/profile.py
@@ -34,13 +34,13 @@ class ResourceProfile(object):
     This API is evolving.
     """
 
-    def __init__(self, _java_resource_profile=None, _exec_req={}, _task_req={}):
+    def __init__(self, _java_resource_profile=None, _exec_req=None, _task_req=None):
         if _java_resource_profile is not None:
             self._java_resource_profile = _java_resource_profile
         else:
             self._java_resource_profile = None
-        self._executor_resource_requests = _exec_req
-        self._task_resource_requests = _task_req
+        self._executor_resource_requests = _exec_req or {}
+        self._task_resource_requests = _task_req or {}
 
     @property
     def id(self):
diff --git a/python/pyspark/resource/profile.pyi b/python/pyspark/resource/profile.pyi
index 04838692436df..c8f23a5cac370 100644
--- a/python/pyspark/resource/profile.pyi
+++ b/python/pyspark/resource/profile.pyi
@@ -22,7 +22,7 @@ from pyspark.resource.requests import (  # noqa: F401
     TaskResourceRequest as TaskResourceRequest,
     TaskResourceRequests as TaskResourceRequests,
 )
-from typing import overload, Dict, Union
+from typing import overload, Dict, Union, Optional
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 class ResourceProfile:
@@ -35,8 +35,8 @@ class ResourceProfile:
     def __init__(
         self,
         _java_resource_profile: None = ...,
-        _exec_req: Dict[str, ExecutorResourceRequest] = ...,
-        _task_req: Dict[str, TaskResourceRequest] = ...,
+        _exec_req: Optional[Dict[str, ExecutorResourceRequest]] = ...,
+        _task_req: Optional[Dict[str, TaskResourceRequest]] = ...,
     ) -> None: ...
     @property
     def id(self) -> int: ...
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index ce322814e34f8..7e4ceb20cd2c4 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -25,7 +25,7 @@
 from pyspark.util import _print_missing_jar
 
 
-def from_avro(data, jsonFormatSchema, options={}):
+def from_avro(data, jsonFormatSchema, options=None):
     """
     Converts a binary column of Avro format into its corresponding catalyst value.
     The specified schema must match the read data, otherwise the behavior is undefined:
@@ -70,7 +70,7 @@ def from_avro(data, jsonFormatSchema, options={}):
     sc = SparkContext._active_spark_context
     try:
         jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
-            _to_java_column(data), jsonFormatSchema, options)
+            _to_java_column(data), jsonFormatSchema, options or {})
     except TypeError as e:
         if str(e) == "'JavaPackage' object is not callable":
             _print_missing_jar("Avro", "avro", "avro", sc.version)
diff --git a/python/pyspark/sql/avro/functions.pyi b/python/pyspark/sql/avro/functions.pyi
index 4c2e3814a9e94..49881335d8fcc 100644
--- a/python/pyspark/sql/avro/functions.pyi
+++ b/python/pyspark/sql/avro/functions.pyi
@@ -16,12 +16,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Dict
+from typing import Dict, Optional
 
 from pyspark.sql._typing import ColumnOrName
 from pyspark.sql.column import Column
 
 def from_avro(
-    data: ColumnOrName, jsonFormatSchema: str, options: Dict[str, str] = ...
+    data: ColumnOrName, jsonFormatSchema: str, options: Optional[Dict[str, str]] = ...
 ) -> Column: ...
 def to_avro(data: ColumnOrName, jsonFormatSchema: str = ...) -> Column: ...
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4dc3129fd6bc2..f612d2d0366f2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -80,8 +80,10 @@ def _invoke_binary_math_function(name, col1, col2):
     )
 
 
-def _options_to_str(options):
-    return {key: to_str(value) for (key, value) in options.items()}
+def _options_to_str(options=None):
+    if options:
+        return {key: to_str(value) for (key, value) in options.items()}
+    return {}
 
 
 def lit(col):
@@ -3454,7 +3456,7 @@ def json_tuple(col, *fields):
     return Column(jc)
 
 
-def from_json(col, schema, options={}):
+def from_json(col, schema, options=None):
     """
     Parses a column containing a JSON string into a :class:`MapType` with
     :class:`StringType` as keys type, :class:`StructType` or :class:`ArrayType` with
@@ -3510,7 +3512,7 @@ def from_json(col, schema, options={}):
     return Column(jc)
 
 
-def to_json(col, options={}):
+def to_json(col, options=None):
     """
     Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType`
     into a JSON string. Throws an exception, in the case of an unsupported type.
@@ -3557,7 +3559,7 @@ def to_json(col, options={}):
     return Column(jc)
 
 
-def schema_of_json(json, options={}):
+def schema_of_json(json, options=None):
     """
     Parses a JSON string and infers its schema in DDL format.
 
@@ -3594,7 +3596,7 @@ def schema_of_json(json, options={}):
     return Column(jc)
 
 
-def schema_of_csv(csv, options={}):
+def schema_of_csv(csv, options=None):
     """
     Parses a CSV string and infers its schema in DDL format.
 
@@ -3627,7 +3629,7 @@ def schema_of_csv(csv, options={}):
     return Column(jc)
 
 
-def to_csv(col, options={}):
+def to_csv(col, options=None):
     """
     Converts a column containing a :class:`StructType` into a CSV string.
     Throws an exception, in the case of an unsupported type.
@@ -4038,7 +4040,7 @@ def sequence(start, stop, step=None):
         _to_java_column(start), _to_java_column(stop), _to_java_column(step)))
 
 
-def from_csv(col, schema, options={}):
+def from_csv(col, schema, options=None):
     """
     Parses a column containing a CSV string to a row with the specified schema.
     Returns `null`, in the case of an unparseable string.
diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi
index 50e178df9996f..acb17a2657d00 100644
--- a/python/pyspark/sql/functions.pyi
+++ b/python/pyspark/sql/functions.pyi
@@ -196,12 +196,12 @@ def json_tuple(col: ColumnOrName, *fields: str) -> Column: ...
 def from_json(
     col: ColumnOrName,
     schema: Union[ArrayType, StructType, Column, str],
-    options: Dict[str, str] = ...,
+    options: Optional[Dict[str, str]] = ...,
 ) -> Column: ...
-def to_json(col: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...
-def schema_of_json(json: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...
-def schema_of_csv(csv: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...
-def to_csv(col: ColumnOrName, options: Dict[str, str] = ...) -> Column: ...
+def to_json(col: ColumnOrName, options: Optional[Dict[str, str]] = ...) -> Column: ...
+def schema_of_json(json: ColumnOrName, options: Optional[Dict[str, str]] = ...) -> Column: ...
+def schema_of_csv(csv: ColumnOrName, options: Optional[Dict[str, str]] = ...) -> Column: ...
+def to_csv(col: ColumnOrName, options: Optional[Dict[str, str]] = ...) -> Column: ...
 def size(col: ColumnOrName) -> Column: ...
 def array_min(col: ColumnOrName) -> Column: ...
 def array_max(col: ColumnOrName) -> Column: ...
@@ -223,7 +223,7 @@ def sequence(
 def from_csv(
     col: ColumnOrName,
     schema: Union[StructType, Column, str],
-    options: Dict[str, str] = ...,
+    options: Optional[Dict[str, str]] = ...,
 ) -> Column: ...
 @overload
 def transform(col: ColumnOrName, f: Callable[[Column], Column]) -> Column: ...
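
Every change above applies the same pattern: a mutable default value (`{}` or `[]`) becomes a `None` sentinel that is resolved to a fresh object inside the function body, and the matching `.pyi` stubs gain `Optional[...]` so they type-check with the newly enabled `no_implicit_optional = True`. The following is a minimal standalone sketch, not part of the patch, of the pitfall that the newly selected flake8-bugbear rule B006 flags and the fix used here; the function names `add_option_buggy` and `add_option_fixed` are made up for illustration.

    from typing import Dict, Optional

    # Buggy: the dict is created once, when the function is defined, and is
    # silently shared by every call that omits `options` (flake8-bugbear B006).
    def add_option_buggy(key, value, options={}):
        options[key] = value
        return options

    # Fixed: use a None sentinel and build a fresh dict per call, mirroring the
    # `options=None` / `options or {}` changes in this patch. With mypy's
    # no_implicit_optional = True, the annotation must also be Optional[...].
    def add_option_fixed(key: str, value: str,
                         options: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        options = options or {}
        options[key] = value
        return options

    assert add_option_buggy("a", "1") is add_option_buggy("b", "2")      # same shared dict
    assert add_option_fixed("a", "1") is not add_option_fixed("b", "2")  # fresh dict each call

Note that the `x or {}` idiom also replaces an explicitly passed empty dict or list with a fresh one; for these option and metric parameters that is behaviorally equivalent, which is presumably why it was preferred over an explicit `if x is None` check.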