diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
index ab812e1bb7c0..8cc990d87184 100755
--- a/dev/create-release/releaseutils.py
+++ b/dev/create-release/releaseutils.py
@@ -50,7 +50,7 @@
     sys.exit(-1)
 
 if sys.version < '3':
-    input = raw_input
+    input = raw_input  # noqa
 
 # Contributors list file name
 contributors_file_name = "contributors.txt"
@@ -152,7 +152,11 @@ def get_commits(tag):
         if not is_valid_author(author):
             author = github_username
         # Guard against special characters
-        author = unidecode.unidecode(unicode(author, "UTF-8")).strip()
+        try:  # Python 2
+            author = unicode(author, "UTF-8")
+        except NameError:  # Python 3
+            author = str(author)
+        author = unidecode.unidecode(author).strip()
         commit = Commit(_hash, author, title, pr_number)
         commits.append(commit)
     return commits
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index fe05282efdd4..28a6714856c1 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -40,7 +40,7 @@
 JIRA_IMPORTED = False
 
 if sys.version < '3':
-    input = raw_input
+    input = raw_input  # noqa
 
 # Location of your Spark git development area
 SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd())
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index f80bf598c221..71ea1631718f 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -20,6 +20,9 @@
 from pyspark import since, _NoValue
 from pyspark.rdd import ignore_unicode_prefix
 
+if sys.version_info[0] >= 3:
+    basestring = str
+
 
 class RuntimeConfig(object):
     """User-facing configuration API, accessible through `SparkSession.conf`.
@@ -59,7 +62,7 @@ def unset(self, key):
 
     def _checkType(self, obj, identifier):
         """Assert that an object is of type str."""
-        if not isinstance(obj, str) and not isinstance(obj, unicode):
+        if not isinstance(obj, basestring):
             raise TypeError("expected %s '%s' to be a string (was '%s')" %
                             (identifier, obj, type(obj).__name__))
 
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 8c1fd4af674d..ee13778a7dcd 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -19,10 +19,7 @@
 import json
 
 if sys.version >= '3':
-    intlike = int
-    basestring = unicode = str
-else:
-    intlike = (int, long)
+    basestring = str
 
 from py4j.java_gateway import java_import
 
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 59977dcb435a..ce42a857d0c0 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -23,6 +23,8 @@
 
 if sys.version < "3":
     from itertools import imap as map, ifilter as filter
+else:
+    long = int
 
 from py4j.protocol import Py4JJavaError
 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 09af47a597be..5cef621a28e6 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -179,7 +179,7 @@ def func(dstream):
         self._test_func(input, func, expected)
 
     def test_flatMap(self):
-        """Basic operation test for DStream.faltMap."""
+        """Basic operation test for DStream.flatMap."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
 
         def func(dstream):
@@ -206,6 +206,38 @@ def func(dstream):
         expected = [[len(x)] for x in input]
         self._test_func(input, func, expected)
 
+    def test_slice(self):
+        """Basic operation test for DStream.slice."""
+        import datetime as dt
+        self.ssc = StreamingContext(self.sc, 1.0)
+        self.ssc.remember(4.0)
+        input = [[1], [2], [3], [4]]
+        stream = self.ssc.queueStream([self.sc.parallelize(d, 1) for d in input])
+
+        time_vals = []
+
+        def get_times(t, rdd):
+            if rdd and len(time_vals) < len(input):
+                time_vals.append(t)
+
+        stream.foreachRDD(get_times)
+
+        self.ssc.start()
+        self.wait_for(time_vals, 4)
+        begin_time = time_vals[0]
+
+        def get_sliced(begin_delta, end_delta):
+            begin = begin_time + dt.timedelta(seconds=begin_delta)
+            end = begin_time + dt.timedelta(seconds=end_delta)
+            rdds = stream.slice(begin, end)
+            result_list = [rdd.collect() for rdd in rdds]
+            return [r for result in result_list for r in result]
+
+        self.assertEqual(set([1]), set(get_sliced(0, 0)))
+        self.assertEqual(set([2, 3]), set(get_sliced(1, 2)))
+        self.assertEqual(set([2, 3, 4]), set(get_sliced(1, 4)))
+        self.assertEqual(set([1, 2, 3, 4]), set(get_sliced(0, 4)))
+
     def test_reduce(self):
         """Basic operation test for DStream.reduce."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
diff --git a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py
index 341a1b40e07a..5b360208d36f 100644
--- a/sql/hive/src/test/resources/data/scripts/dumpdata_script.py
+++ b/sql/hive/src/test/resources/data/scripts/dumpdata_script.py
@@ -18,6 +18,9 @@
 #
 import sys
 
+if sys.version_info[0] >= 3:
+    xrange = range
+
 for i in xrange(50):
     for j in xrange(5):
         for k in xrange(20022):
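Note on the compatibility idioms used above (not part of the patch): the diff applies two standard Python 2/3 shims. Where the Python 3 builtin is a drop-in replacement, a missing Python 2 name is simply aliased at module scope (basestring = str, xrange = range, long = int), so every call site stays identical under both interpreters. Where the two versions need genuinely different behavior, as in releaseutils.py, the code attempts the Python 2 builtin and treats NameError as the signal that it is running on Python 3. A minimal standalone sketch of both idioms follows; the helper name decode_author is illustrative and does not appear in the patch:

    import sys

    # Alias the Python 2 builtins that Python 3 removed or renamed, so the
    # code below can use a single spelling on both interpreters.
    if sys.version_info[0] >= 3:
        basestring = str  # Python 3 folded unicode and str into str
        xrange = range    # Python 3's range behaves like Python 2's xrange
        long = int        # Python 3 merged long into int


    def decode_author(author):
        # EAFP variant of the releaseutils.py change: `unicode` exists only
        # on Python 2, so a NameError means we are running on Python 3.
        try:
            return unicode(author, "UTF-8")  # Python 2: decode raw bytes
        except NameError:
            return str(author)  # Python 3: str is already text


    print(isinstance(decode_author("Matei Zaharia"), basestring))  # True on 2 and 3
    print(sum(long(i) for i in xrange(3)))  # 3 on either interpreter

The try/except form is used in get_commits rather than another module-scope alias because that call site needs different behavior, not just a different name: on Python 2 the author string is raw bytes that must be decoded, while on Python 3 it is already text.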