Commit bb96012

davies authored and JoshRosen committed
[SPARK-3679] [PySpark] pickle the exact globals of functions
function.func_code.co_names has all the names used in the function, including the names of attributes. It will pickle some unnecessary globals if there is a global with the same name as an attribute (in co_names).

There is a regression introduced by apache#2144; this reverts part of the changes in that PR.

cc JoshRosen

Author: Davies Liu <[email protected]>

Closes apache#2522 from davies/globals and squashes the following commits:

dfbccf5 [Davies Liu] fix bug while pickle globals of function
1 parent c854b9f commit bb96012
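
To illustrate the problem described above, a minimal sketch (not part of the commit; written for Python 3, where func_code is spelled __code__):

    import dis
    import sys

    def foo():
        sys.exit(0)

    # co_names lists every name the bytecode mentions, including the
    # attribute "exit", even though foo never reads a global named exit.
    print(foo.__code__.co_names)        # ('sys', 'exit')

    # Only the LOAD/STORE/DELETE_GLOBAL opcodes mark real global references,
    # so a module-level object that happens to be called "exit" is ignored.
    real_globals = {ins.argval for ins in dis.get_instructions(foo)
                    if ins.opname in ("LOAD_GLOBAL", "STORE_GLOBAL", "DELETE_GLOBAL")}
    print(real_globals)                 # {'sys'}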

File tree

2 files changed: +54 -6 lines changed

python/pyspark/cloudpickle.py

Lines changed: 36 additions & 6 deletions
@@ -52,6 +52,7 @@
 import itertools
 from copy_reg import _extension_registry, _inverted_registry, _extension_cache
 import new
+import dis
 import traceback
 import platform

@@ -61,6 +62,14 @@
 import logging
 cloudLog = logging.getLogger("Cloud.Transport")

+#relevant opcodes
+STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
+DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
+LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
+GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+
+HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
+EXTENDED_ARG = chr(dis.EXTENDED_ARG)

 if PyImp == "PyPy":
     # register builtin type in `new`
@@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports):
         write(pickle.REDUCE) # applies _fill_function on the tuple

     @staticmethod
-    def extract_code_globals(code):
+    def extract_code_globals(co):
         """
         Find all globals names read or written to by codeblock co
         """
-        names = set(code.co_names)
-        if code.co_consts:   # see if nested function have any global refs
-            for const in code.co_consts:
+        code = co.co_code
+        names = co.co_names
+        out_names = set()
+
+        n = len(code)
+        i = 0
+        extended_arg = 0
+        while i < n:
+            op = code[i]
+
+            i = i+1
+            if op >= HAVE_ARGUMENT:
+                oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
+                extended_arg = 0
+                i = i+2
+                if op == EXTENDED_ARG:
+                    extended_arg = oparg*65536L
+                if op in GLOBAL_OPS:
+                    out_names.add(names[oparg])
+        #print 'extracted', out_names, ' from ', names
+
+        if co.co_consts:   # see if nested function have any global refs
+            for const in co.co_consts:
                 if type(const) is types.CodeType:
-                    names |= CloudPickler.extract_code_globals(const)
-        return names
+                    out_names |= CloudPickler.extract_code_globals(const)
+
+        return out_names

     def extract_func_data(self, func):
         """

python/pyspark/tests.py

Lines changed: 18 additions & 0 deletions
@@ -213,6 +213,24 @@ def test_pickling_file_handles(self):
         out2 = ser.loads(ser.dumps(out1))
         self.assertEquals(out1, out2)

+    def test_func_globals(self):
+
+        class Unpicklable(object):
+            def __reduce__(self):
+                raise Exception("not picklable")
+
+        global exit
+        exit = Unpicklable()
+
+        ser = CloudPickleSerializer()
+        self.assertRaises(Exception, lambda: ser.dumps(exit))
+
+        def foo():
+            sys.exit(0)
+
+        self.assertTrue("exit" in foo.func_code.co_names)
+        ser.dumps(foo)
+

 class PySparkTestCase(unittest.TestCase):

