Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import itertools
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import new
import dis
import traceback
import platform

Expand All @@ -61,6 +62,14 @@
import logging
cloudLog = logging.getLogger("Cloud.Transport")

# Bytecode opcodes that read or write module-level (global) names.
# Stored as 1-char strings because Python 2 exposes co_code as a str.
STORE_GLOBAL = chr(dis.opmap['STORE_GLOBAL'])
DELETE_GLOBAL = chr(dis.opmap['DELETE_GLOBAL'])
LOAD_GLOBAL = chr(dis.opmap['LOAD_GLOBAL'])
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]

# Boundary opcode: everything >= HAVE_ARGUMENT carries a 2-byte argument.
HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
# EXTENDED_ARG widens the following instruction's argument by 16 bits.
EXTENDED_ARG = chr(dis.EXTENDED_ARG)

if PyImp == "PyPy":
# register builtin type in `new`
Expand Down Expand Up @@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports):
write(pickle.REDUCE) # applies _fill_function on the tuple

@staticmethod
def extract_code_globals(co):
    """
    Find all global names read or written to by code block co.

    Scans the raw bytecode of `co` for STORE_GLOBAL / DELETE_GLOBAL /
    LOAD_GLOBAL instructions and collects the names they reference,
    then recurses into any nested code objects found in co_consts
    (i.e. functions defined by `def` or `lambda` inside this one).

    Returns a set of name strings.

    NOTE: this is Python 2 bytecode scanning — co_code is a str and
    each opcode is a 1-char string, compared against the chr() opcode
    constants defined at module level.
    """
    code = co.co_code       # raw bytecode (a str in Python 2)
    names = co.co_names     # tuple of global/attribute names used by co
    out_names = set()

    n = len(code)
    i = 0
    extended_arg = 0
    while i < n:
        op = code[i]

        i = i+1
        if op >= HAVE_ARGUMENT:
            # 2-byte little-endian argument, widened by any preceding
            # EXTENDED_ARG instruction.
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                # Remember the high 16 bits for the next instruction.
                extended_arg = oparg*65536L
            if op in GLOBAL_OPS:
                # oparg indexes into co_names for the *_GLOBAL opcodes.
                out_names.add(names[oparg])

    # See if nested functions (code objects stored in co_consts) have
    # any global refs of their own; recursion terminates because code
    # objects created by `def`/`lambda` cannot be cyclically nested.
    if co.co_consts:
        for const in co.co_consts:
            if type(const) is types.CodeType:
                out_names |= CloudPickler.extract_code_globals(const)

    return out_names

def extract_func_data(self, func):
"""
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ def test_pickling_file_handles(self):
out2 = ser.loads(ser.dumps(out1))
self.assertEquals(out1, out2)

def test_func_globals(self):
    """Pickling a function must not drag in unrelated, unpicklable globals."""

    class Unpicklable(object):
        def __reduce__(self):
            raise Exception("not picklable")

    # Rebind the module-level name `exit` to an object that cannot
    # be pickled.
    global exit
    exit = Unpicklable()

    serializer = CloudPickleSerializer()

    # Serializing the unpicklable object itself must raise.
    self.assertRaises(Exception, serializer.dumps, exit)

    def foo():
        sys.exit(0)

    # "exit" appears in foo's co_names (via the sys.exit attribute
    # access), but foo never reads the global `exit`, so cloudpickle
    # should still be able to serialize it.
    self.assertTrue("exit" in foo.func_code.co_names)
    serializer.dumps(foo)


class PySparkTestCase(unittest.TestCase):

Expand Down