 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
 from pyspark.context import SparkContext
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.duration import Duration
+from pyspark.streaming.duration import Duration, Seconds
 
 from py4j.java_collections import ListConverter
 
@@ -35,68 +35,31 @@ class StreamingContext(object):
     broadcast variables on that cluster.
     """
 
-    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
-                 gateway=None, sparkContext=None, duration=None):
+    def __init__(self, sparkContext, duration):
         """
         Create a new StreamingContext. At least the master and app name and duration
         should be set, either through the named parameters here or through C{conf}.
 
-        @param master: Cluster URL to connect to
-               (e.g. mesos://host:port, spark://host:port, local[4]).
-        @param appName: A name for your job, to display on the cluster web UI.
-        @param sparkHome: Location where Spark is installed on cluster nodes.
-        @param pyFiles: Collection of .zip or .py files to send to the cluster
-               and add to PYTHONPATH. These can be paths on the local file
-               system or HDFS, HTTP, HTTPS, or FTP URLs.
-        @param environment: A dictionary of environment variables to set on
-               worker nodes.
-        @param batchSize: The number of Python objects represented as a single
-               Java object. Set 1 to disable batching or -1 to use an
-               unlimited batch size.
-        @param serializer: The serializer for RDDs.
-        @param conf: A L{SparkConf} object setting Spark properties.
-        @param gateway: Use an existing gateway and JVM, otherwise a new JVM
-               will be instantiated.
         @param sparkContext: L{SparkContext} object.
-        @param duration: A L{Duration} object for SparkStreaming.
+        @param duration: A L{Duration} object or seconds for SparkStreaming.
 
         """
+        if isinstance(duration, (int, long, float)):
+            duration = Seconds(duration)
 
-        if not isinstance(duration, Duration):
-            raise TypeError("Input should be pyspark.streaming.duration.Duration object")
-
-        if sparkContext is None:
-            # Create the Python SparkContext
-            self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome,
-                                    pyFiles=pyFiles, environment=environment, batchSize=batchSize,
-                                    serializer=serializer, conf=conf, gateway=gateway)
-        else:
-            self._sc = sparkContext
-
-        # Start the py4j callback server.
-        # The callback server is needed only by Spark Streaming; therefore it
-        # is started in StreamingContext.
-        SparkContext._gateway.restart_callback_server()
-        self._set_clean_up_handler()
+        self._sc = sparkContext
         self._jvm = self._sc._jvm
-        self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)
+        self._start_callback_server()
+        self._jssc = self._initialize_context(self._sc, duration)
 
-    # Initialize StreamingContext in a function to allow subclass-specific initialization
-    def _initialize_context(self, jspark_context, jduration):
-        return self._jvm.JavaStreamingContext(jspark_context, jduration)
+    def _start_callback_server(self):
+        gw = self._sc._gateway
+        # getattr would fall back to the JVM, so check __dict__ directly
+        if "_callback_server" not in gw.__dict__:
+            gw._start_callback_server(gw._python_proxy_port)
 
-    def _set_clean_up_handler(self):
-        """ Set a clean-up handler using atexit """
-
-        def clean_up_handler():
-            SparkContext._gateway.shutdown()
-
-        atexit.register(clean_up_handler)
-        # atexit is not called when the program is killed by a signal not handled by
-        # Python.
-        for sig in (SIGINT, SIGTERM):
-            signal(sig, clean_up_handler)
+    def _initialize_context(self, sc, duration):
+        return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
 
     @property
     def sparkContext(self):
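After this change the constructor no longer builds a SparkContext itself: it takes an existing one plus a batch duration, and a bare number is wrapped in Seconds. A minimal usage sketch, assuming this module stays importable as pyspark.streaming.context; the master URL and app name below are illustrative:

    from pyspark.context import SparkContext
    from pyspark.streaming.context import StreamingContext
    from pyspark.streaming.duration import Seconds

    sc = SparkContext("local[2]", "StreamingExample")  # existing SparkContext (illustrative master/app name)
    ssc = StreamingContext(sc, 1)                      # plain number is coerced to Seconds(1) by __init__
    # equivalent, passing the Duration wrapper explicitly:
    # ssc = StreamingContext(sc, Seconds(1))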
@@ -121,17 +84,26 @@ def awaitTermination(self, timeout=None):
         else:
             self._jssc.awaitTermination(timeout)
 
+    def stop(self, stopSparkContext=True, stopGraceFully=False):
+        """
+        Stop the execution of the streams immediately (does not wait for all received data
+        to be processed).
+        """
+        self._jssc.stop(stopSparkContext, stopGraceFully)
+        if stopSparkContext:
+            self._sc.stop()
+
     def remember(self, duration):
         """
         Set each DStream in this context to remember RDDs it generated in the last given duration.
         DStreams remember RDDs only for a limited duration of time and release them for garbage
         collection. This method allows the developer to specify how long to remember the RDDs (
         if the developer wishes to query old data outside the DStream computation).
-        @param duration pyspark.streaming.duration.Duration object.
+        @param duration pyspark.streaming.duration.Duration object or seconds.
               Minimum duration that each DStream should remember its RDDs
         """
-        if not isinstance(duration, Duration):
-            raise TypeError("Input should be pyspark.streaming.duration.Duration object")
+        if isinstance(duration, (int, long, float)):
+            duration = Seconds(duration)
 
         self._jssc.remember(duration._jduration)
 
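With stop() moved here and remember() accepting a plain number, caller code reads roughly as below. This is a sketch only; ssc is the StreamingContext from the earlier example and the values are illustrative:

    ssc.remember(60)                   # coerced to Seconds(60): keep generated RDDs for one minute
    ssc.remember(Seconds(60))          # same effect with an explicit Duration

    ssc.stop(stopSparkContext=False)   # stop streaming only; the underlying SparkContext stays alive
    # ssc.stop()                       # default: also stops the SparkContext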
@@ -153,34 +125,14 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def stop(self, stopSparkContext=True, stopGraceFully=False):
-        """
-        Stop the execution of the streams immediately (does not wait for all received data
-        to be processed).
-        """
-        self._jssc.stop(stopSparkContext, stopGraceFully)
-        if stopSparkContext:
-            self._sc.stop()
-
-        # Shut down only the callback server; the py4j clients are shut down by the
-        # clean-up handler.
-        SparkContext._gateway._shutdown_callback_server()
-
-    def _testInputStream(self, test_inputs, numSlices=None):
+    def _makeStream(self, inputs, numSlices=None):
         """
         This function is only for unit tests.
         It requires a list as input, and returns the i-th element at the i-th batch
         under the manual clock.
         """
-        test_rdds = list()
-        test_rdd_deserializers = list()
-        for test_input in test_inputs:
-            test_rdd = self._sc.parallelize(test_input, numSlices)
-            test_rdds.append(test_rdd._jrdd)
-            test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
-        # All deserializers have to be the same.
-        # TODO: add deserializer validation
-        jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()
-
-        return DStream(jinput_stream, self, test_rdd_deserializers[0])
+        rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+        jrdds = ListConverter().convert([r._jrdd for r in rdds],
+                                        SparkContext._gateway._gateway_client)
+        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
+        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
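_makeStream parallelizes each inner list into an RDD, hands the Java RDD handles to PythonDataInputStream, and wraps the resulting Java DStream with the first RDD's deserializer. A rough test-only sketch, assuming a manual-clock test configuration and that DStream.map is available in this branch; the data is illustrative:

    # each inner list becomes the RDD served for one batch under the manual clock
    batches = [[1, 2, 3], [4, 5], [6]]
    dstream = ssc._makeStream(batches, numSlices=2)
    doubled = dstream.map(lambda x: x * 2)  # ordinary DStream transformations apply

As in the removed _testInputStream, all batches are assumed to share one serialization format; only the first RDD's deserializer is consulted.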