Commit 7f53086

support transform(), refactor and cleanup

1 parent df098fc

11 files changed: +384 −905 lines

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 0 deletions

@@ -52,6 +52,10 @@ private[spark] class PythonRDD(
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
+  def copyTo(rdd: RDD[_]): PythonRDD = {
+    new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator)
+  }
+
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
   val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
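The only Scala change is the new copyTo helper, which rebuilds a PythonRDD on top of a different parent RDD while reusing the same serialized command, environment, includes, broadcast variables, and accumulator. This appears to be the hook that the Python-side transform() support relies on. Below is a minimal, hedged sketch of what that enables for a user, assuming a DStream.transform(func) method exists on the Python side (it is named in the commit title but not shown in this section):

    # Sketch only: DStream.transform() is assumed from the commit title, not shown in this diff.
    from pyspark import SparkContext
    from pyspark.streaming.context import StreamingContext

    sc = SparkContext("local[2]", "transform-sketch")
    ssc = StreamingContext(sc, 1)                     # batch duration given in plain seconds

    lines = ssc.textFileStream("/tmp/stream-input")   # hypothetical input directory
    # transform() would run an arbitrary RDD-to-RDD function on every batch, making
    # RDD-only operations such as sortBy() usable on a DStream.
    sorted_lines = lines.transform(lambda rdd: rdd.sortBy(lambda x: x))
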
python/pyspark/streaming/context.py

Lines changed: 33 additions & 81 deletions

@@ -23,7 +23,7 @@
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
 from pyspark.context import SparkContext
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.duration import Duration
+from pyspark.streaming.duration import Duration, Seconds
 
 from py4j.java_collections import ListConverter
 
@@ -35,68 +35,31 @@ class StreamingContext(object):
     broadcast variables on that cluster.
     """
 
-    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
-                 gateway=None, sparkContext=None, duration=None):
+    def __init__(self, sparkContext, duration):
         """
         Create a new StreamingContext. At least the master and app name and duration
         should be set, either through the named parameters here or through C{conf}.
 
-        @param master: Cluster URL to connect to
-               (e.g. mesos://host:port, spark://host:port, local[4]).
-        @param appName: A name for your job, to display on the cluster web UI.
-        @param sparkHome: Location where Spark is installed on cluster nodes.
-        @param pyFiles: Collection of .zip or .py files to send to the cluster
-               and add to PYTHONPATH. These can be paths on the local file
-               system or HDFS, HTTP, HTTPS, or FTP URLs.
-        @param environment: A dictionary of environment variables to set on
-               worker nodes.
-        @param batchSize: The number of Python objects represented as a single
-               Java object. Set 1 to disable batching or -1 to use an
-               unlimited batch size.
-        @param serializer: The serializer for RDDs.
-        @param conf: A L{SparkConf} object setting Spark properties.
-        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
-               will be instatiated.
         @param sparkContext: L{SparkContext} object.
-        @param duration: A L{Duration} object for SparkStreaming.
+        @param duration: A L{Duration} object or seconds for SparkStreaming.
 
         """
+        if isinstance(duration, (int, long, float)):
+            duration = Seconds(duration)
 
-        if not isinstance(duration, Duration):
-            raise TypeError("Input should be pyspark.streaming.duration.Duration object")
-
-        if sparkContext is None:
-            # Create the Python Sparkcontext
-            self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome,
-                                    pyFiles=pyFiles, environment=environment, batchSize=batchSize,
-                                    serializer=serializer, conf=conf, gateway=gateway)
-        else:
-            self._sc = sparkContext
-
-        # Start py4j callback server.
-        # Callback sever is need only by SparkStreming; therefore the callback sever
-        # is started in StreamingContext.
-        SparkContext._gateway.restart_callback_server()
-        self._set_clean_up_handler()
+        self._sc = sparkContext
         self._jvm = self._sc._jvm
-        self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)
+        self._start_callback_server()
+        self._jssc = self._initialize_context(self._sc, duration)
 
-    # Initialize StremaingContext in function to allow subclass specific initialization
-    def _initialize_context(self, jspark_context, jduration):
-        return self._jvm.JavaStreamingContext(jspark_context, jduration)
+    def _start_callback_server(self):
+        gw = self._sc._gateway
+        # getattr will fallback to JVM
+        if "_callback_server" not in gw.__dict__:
+            gw._start_callback_server(gw._python_proxy_port)
 
-    def _set_clean_up_handler(self):
-        """ set clean up hander using atexit """
-
-        def clean_up_handler():
-            SparkContext._gateway.shutdown()
-
-        atexit.register(clean_up_handler)
-        # atext is not called when the program is killed by a signal not handled by
-        # Python.
-        for sig in (SIGINT, SIGTERM):
-            signal(sig, clean_up_handler)
+    def _initialize_context(self, sc, duration):
+        return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
 
     @property
     def sparkContext(self):
@@ -121,17 +84,26 @@ def awaitTermination(self, timeout=None):
         else:
             self._jssc.awaitTermination(timeout)
 
+    def stop(self, stopSparkContext=True, stopGraceFully=False):
+        """
+        Stop the execution of the streams immediately (does not wait for all received data
+        to be processed).
+        """
+        self._jssc.stop(stopSparkContext, stopGraceFully)
+        if stopSparkContext:
+            self._sc.stop()
+
     def remember(self, duration):
         """
         Set each DStreams in this context to remember RDDs it generated in the last given duration.
         DStreams remember RDDs only for a limited duration of time and releases them for garbage
         collection. This method allows the developer to specify how to long to remember the RDDs (
         if the developer wishes to query old data outside the DStream computation).
-        @param duration pyspark.streaming.duration.Duration object.
+        @param duration pyspark.streaming.duration.Duration object or seconds.
               Minimum duration that each DStream should remember its RDDs
         """
-        if not isinstance(duration, Duration):
-            raise TypeError("Input should be pyspark.streaming.duration.Duration object")
+        if isinstance(duration, (int, long, float)):
+            duration = Seconds(duration)
 
         self._jssc.remember(duration._jduration)
 
@@ -153,34 +125,14 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def stop(self, stopSparkContext=True, stopGraceFully=False):
-        """
-        Stop the execution of the streams immediately (does not wait for all received data
-        to be processed).
-        """
-        self._jssc.stop(stopSparkContext, stopGraceFully)
-        if stopSparkContext:
-            self._sc.stop()
-
-        # Shutdown only callback server and all py3j client is shutdowned
-        # clean up handler
-        SparkContext._gateway._shutdown_callback_server()
-
-    def _testInputStream(self, test_inputs, numSlices=None):
+    def _makeStream(self, inputs, numSlices=None):
         """
         This function is only for unittest.
         It requires a list as input, and returns the i_th element at the i_th batch
         under manual clock.
         """
-        test_rdds = list()
-        test_rdd_deserializers = list()
-        for test_input in test_inputs:
-            test_rdd = self._sc.parallelize(test_input, numSlices)
-            test_rdds.append(test_rdd._jrdd)
-            test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
-        # All deserializers have to be the same.
-        # TODO: add deserializer validation
-        jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()
-
-        return DStream(jinput_stream, self, test_rdd_deserializers[0])
+        rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+        jrdds = ListConverter().convert([r._jrdd for r in rdds],
+                                        SparkContext._gateway._gateway_client)
+        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
+        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
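
Taken together, the context.py changes make StreamingContext a thin wrapper around an existing SparkContext: the constructor is reduced to (sparkContext, duration), plain numbers are coerced to Seconds, the py4j callback server is started only once, and stop() optionally stops the underlying SparkContext as well. A short usage sketch based only on the signatures shown above; the input directory is hypothetical, and output operations plus start() are omitted because they are outside this diff:

    from pyspark import SparkContext
    from pyspark.streaming.context import StreamingContext

    sc = SparkContext("local[2]", "streaming-sketch")

    ssc = StreamingContext(sc, 2)        # 2 is coerced to Seconds(2) by the new constructor
    ssc.remember(60)                     # remember() now also accepts plain seconds
    lines = ssc.textFileStream("/tmp/stream-input")   # hypothetical input directory

    # ... register output operations and call ssc.start() here (not shown in this diff) ...

    ssc.awaitTermination(timeout=30)
    ssc.stop(stopSparkContext=True, stopGraceFully=False)   # also stops the SparkContext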
