Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add nested dict support #208

Merged
merged 20 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions lupa/_lupa.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ cdef object exc_info
from sys import exc_info

cdef object Mapping
cdef object Sequence
try:
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
except ImportError:
from collections import Mapping # Py2
from collections import Mapping, Sequence # Py2

cdef object wraps
from functools import wraps
Expand Down Expand Up @@ -169,6 +170,10 @@ def lua_type(obj):
lua.lua_settop(L, old_top)
unlock_runtime(lua_object._runtime)

cdef inline int _len_as_int(Py_ssize_t obj) except -1:
if obj > <Py_ssize_t>INT_MAX:
raise OverflowError
return <int>obj

@cython.no_gc_clear
cdef class LuaRuntime:
Expand Down Expand Up @@ -520,21 +525,22 @@ cdef class LuaRuntime:
"""
return self.table_from(items, kwargs)

def table_from(self, *args):
def table_from(self, *args, bint recursive=False):
"""Create a new table from Python mapping or iterable.

table_from() accepts either a dict/mapping or an iterable with items.
Items from dicts are set as key-value pairs; items from iterables
are placed in the table in order.

Nested mappings / iterables are passed to Lua as userdata
(wrapped Python objects); they are not converted to Lua tables.
(wrapped Python objects). If `recursive` is False (the default),
they are not converted to Lua tables.
"""
assert self._state is not NULL
cdef lua_State *L = self._state
lock_runtime(self)
try:
return py_to_lua_table(self, L, args)
return py_to_lua_table(self, L, args, recursive=recursive)
finally:
unlock_runtime(self)

Expand Down Expand Up @@ -1236,7 +1242,7 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args):
# already terminated
raise StopIteration
if args:
nargs = len(args)
nargs = _len_as_int(len(args))
push_lua_arguments(thread._runtime, co, args)
with nogil:
status = lua.lua_resume(co, L, nargs, &nres)
Expand Down Expand Up @@ -1482,7 +1488,7 @@ cdef py_object* unpack_userdata(lua_State *L, int n) noexcept nogil:
cdef int py_function_result_to_lua(LuaRuntime runtime, lua_State *L, object o) except -1:
if runtime._unpack_returned_tuples and isinstance(o, tuple):
push_lua_arguments(runtime, L, <tuple>o)
return len(<tuple>o)
return _len_as_int(len(<tuple>o))
check_lua_stack(L, 1)
return py_to_lua(runtime, L, o)

Expand Down Expand Up @@ -1511,7 +1517,7 @@ cdef int py_to_lua_handle_overflow(LuaRuntime runtime, lua_State *L, object o) e
lua.lua_settop(L, old_top)
raise

cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False) except -1:
cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=False, bint recursive=False, dict mapped_objs=None) except -1:
"""Converts Python object to Lua
Preconditions:
1 extra slot in the Lua stack
Expand Down Expand Up @@ -1563,13 +1569,19 @@ cdef int py_to_lua(LuaRuntime runtime, lua_State *L, object o, bint wrap_none=Fa
elif isinstance(o, float):
lua.lua_pushnumber(L, <lua.lua_Number><double>o)
pushed_values_count = 1
elif isinstance(o, _PyProtocolWrapper):
type_flags = (<_PyProtocolWrapper> o)._type_flags
o = (<_PyProtocolWrapper> o)._obj
pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags)
elif recursive and isinstance(o, (list, dict, Sequence, Mapping)):
if mapped_objs is None:
mapped_objs = {}
table = py_to_lua_table(runtime, L, (o,), recursive=recursive, mapped_objs=mapped_objs)
(<_LuaObject> table).push_lua_object(L)
pushed_values_count = 1
else:
if isinstance(o, _PyProtocolWrapper):
type_flags = (<_PyProtocolWrapper>o)._type_flags
o = (<_PyProtocolWrapper>o)._obj
else:
# prefer __getitem__ over __getattr__ by default
type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0
# prefer __getitem__ over __getattr__ by default
type_flags = OBJ_AS_INDEX if hasattr(o, '__getitem__') else 0
pushed_values_count = py_to_lua_custom(runtime, L, o, type_flags)
return pushed_values_count

Expand Down Expand Up @@ -1655,7 +1667,7 @@ cdef bytes _asciiOrNone(s):
return <bytes>s


cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, tuple items, bint recursive=False, dict mapped_objs=None):
"""
Create a new Lua table and add different kinds of values from the sequence 'items' to it.

Expand All @@ -1666,14 +1678,24 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
check_lua_stack(L, 5)
old_top = lua.lua_gettop(L)
lua.lua_newtable(L)
cdef int lua_table_ref = lua.lua_gettop(L) # the index of the lua table which we are filling
# FIXME: how to check for failure?
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment was referring to the line above. And it's probably outdated, because we can't check for failures, Lua would just jump to the error handler.

Suggested change
cdef int lua_table_ref = lua.lua_gettop(L) # the index of the lua table which we are filling
# FIXME: how to check for failure?
# FIXME: handle allocation errors
cdef int lua_table_ref = lua.lua_gettop(L) # the index of the lua table which we are filling


if recursive and mapped_objs is None:
mapped_objs = {}
try:
for obj in items:
if recursive:
if id(obj) not in mapped_objs:
# this object is never seen before, we should cache it
mapped_objs[id(obj)] = lua_table_ref
else:
# this object has been cached, just get the corresponding lua table's index
idx = mapped_objs[id(obj)]
return new_lua_table(runtime, L, <int>idx)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for that comment, that actually helps. Then you're not using mapped_objs as it's named but rather as a mapped_tables, right? Seems worth renaming then to make the code clearer.

if isinstance(obj, dict):
for key, value in (<dict>obj).items():
py_to_lua(runtime, L, key, wrap_none=True)
py_to_lua(runtime, L, value)
py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_objs=mapped_objs)
py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_objs=mapped_objs)
lua.lua_rawset(L, -3)

elif isinstance(obj, _LuaTable):
Expand All @@ -1689,13 +1711,13 @@ cdef _LuaTable py_to_lua_table(LuaRuntime runtime, lua_State* L, items):
elif isinstance(obj, Mapping):
for key in obj:
value = obj[key]
py_to_lua(runtime, L, key, wrap_none=True)
py_to_lua(runtime, L, value)
py_to_lua(runtime, L, key, wrap_none=True, recursive=recursive, mapped_objs=mapped_objs)
py_to_lua(runtime, L, value, wrap_none=False, recursive=recursive, mapped_objs=mapped_objs)
lua.lua_rawset(L, -3)

else:
for arg in obj:
py_to_lua(runtime, L, arg)
py_to_lua(runtime, L, arg, wrap_none=False, recursive=recursive, mapped_objs=mapped_objs)
lua.lua_rawseti(L, -2, i)
i += 1

Expand Down Expand Up @@ -1826,7 +1848,7 @@ cdef object execute_lua_call(LuaRuntime runtime, lua_State *L, Py_ssize_t nargs)
lua.lua_replace(L, -2)
lua.lua_insert(L, 1)
has_lua_traceback_func = True
result_status = lua.lua_pcall(L, nargs, lua.LUA_MULTRET, has_lua_traceback_func)
result_status = lua.lua_pcall(L, <int>nargs, lua.LUA_MULTRET, has_lua_traceback_func)
if has_lua_traceback_func:
lua.lua_remove(L, 1)
results = unpack_lua_results(runtime, L)
Expand Down Expand Up @@ -2004,7 +2026,7 @@ cdef bint call_python(LuaRuntime runtime, lua_State *L, py_object* py_obj) excep
else:
args = ()
kwargs = {}

for i in range(nargs):
arg = py_from_lua(runtime, L, i+2)
if isinstance(arg, _PyArguments):
Expand Down
69 changes: 65 additions & 4 deletions lupa/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def get_attr(obj, name):


class TestLuaRuntime(SetupLuaRuntimeMixin, LupaTestCase):
def assertLuaResult(self, lua_expression, result):
self.assertEqual(self.lua.eval(lua_expression), result)

def test_lua_version(self):
version = self.lua.lua_version
self.assertEqual(tuple, type(version))
Expand Down Expand Up @@ -598,10 +601,16 @@ def test_table_from_bad(self):
self.assertRaises(TypeError, self.lua.table_from, None)
self.assertRaises(TypeError, self.lua.table_from, {"a": 5}, 123)

# def test_table_from_nested(self):
# table = self.lua.table_from({"obj": {"foo": "bar"}})
# lua_type = self.lua.eval("type")
# self.assertEqual(lua_type(table["obj"]), "table")
def test_table_from_nested(self):
table = self.lua.table_from([[3, 3, 3]], recursive=True)
Comment on lines +604 to +605
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not have been the best example, just something that I came up with trying to understand your implementation. However, a more deeply nested example would be good, to make sure that, say, a dict of lists of dicts of dicts of lists of lists also works, or a (list) table of (mapping) tables. You're only testing one level of nesting, always starting with a dict. Maybe you can generate a data structure (or a few), map it back and forth, and then compare the result to itself? That would allow testing something larger and deeper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accutally there are other tests which are far more deeper, for instance, the test_table_from_self_ref_obj, which is a infinite deep structure.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't so much referring to the depth, rather different combinations of nestings. It might make a difference whether you're mapping a list of dicts or a dict of lists. It might make a difference whether you're mapping a list of simple values or a list of lists of lists. The tests should cover different combinations. Thus, generating data structures and validating the mapping generically would be great, because it would allow quickly testing a larger range of possible structures, making sure that we're not missing edge cases, and making sure we're touching all mapping code. We're probably not covering all of it currently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll try to cover more tests, like List[Dict], Dict[str, Dict], List[List], Dict[str, List]```

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about something like this:

def test():
    from itertools import count
    def make_ds(*children):
        for child in children:
            yield child
        yield list(children)
        yield dict(zip(count(), children))
        yield {chr(ord('A') + i): child for i, child in enumerate(children)}

    for ds1 in make_ds(1, 2, 'x', 'y'):
        for ds2 in make_ds(ds1):
            for ds in make_ds(ds1, ds2):
                table = self.lua.table_from(ds)
                # validate table == ds

I hope it's mostly clear, the idea is to generate 0-2 level data structures of all sorts of nesting combinations, and then validate that the result looks the same in Lua as in Python. The validation would probably involve a couple of type checks and back-mappings Lua->Python, but shouldn't be too complex. The above loops generate somewhat redundant results, AFAICT, but not too bad. Maybe you can come up with something smarter even.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opps... I found a bug, and it can be reproduced in master branch, just use lua.table_from(1), in this way py_to_lua_table got a (1, ), and as it iterate over the tuple, the PyLongObject is passed to

 else:
       for arg in obj:

but int is not iterable. Should fix it somehow so that I can continus with my pr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The structures generated by this function should also be applied to master branch I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, there might be some gc problem with cpython itself, in the loop test there are some lua tables can't match the corresponding pyobject, but when I test them in a individual test, the problem just gone.

Copy link
Owner

@scoder scoder Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My test code proposal was buggy. Here's a revised version that does not generate single elements:

def test():
    from itertools import count
    def make_ds(*children):
        yield list(children)
        yield dict(zip(count(), children))
        yield {chr(ord('A') + i): child for i, child in enumerate(children)}

    elements = [1, 2, 'x', 'y']
    for ds1 in make_ds(*elements):
        for ds2 in make_ds(ds1):
            for ds3 in make_ds(ds1, elements, ds2):
                for ds in make_ds(ds1, ds2, ds3):
                    table = self.lua.table_from(ds)
                    # validate table == ds

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this test to the master branch and merged it into this PR. Since the "validation" in the master branch is really rudimentary, it hopefully won't fail here, but can now be adapted to validate the transitive translation.

self.lua.globals()["data"] = table
self.assertLuaResult("data[1][1]", 3)
self.assertLuaResult("data[1][2]", 3)
self.assertLuaResult("data[1][3]", 3)
self.assertLuaResult("type(data)", "table")
self.assertLuaResult("type(data[1])", "table")
self.assertLuaResult("#data", 1)
self.assertLuaResult("#data[1]", 3)

def test_table_from_table(self):
table1 = self.lua.eval("{3, 4, foo='bar'}")
Expand Down Expand Up @@ -632,6 +641,58 @@ def test_table_from_table_iter_indirect(self):
self.assertEqual(list(table2.keys()), [1, 2, 3])
self.assertEqual(set(table2.values()), set([1, 2, "foo"]))

def test_table_from_nested_dict(self):
data = {"a": {"a": "foo"}, "b": {"b": "bar"}}
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"]["b"], "bar")
self.lua.globals()["data"] = table
self.assertLuaResult("data.a.a", "foo")
self.assertLuaResult("data.b.b", "bar")
self.assertLuaResult("type(data.a)", "table")
self.assertLuaResult("type(data.b)", "table")

def test_table_from_nested_list(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data, recursive=True)
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(table["b"][1], 1)
self.assertEqual(table["b"][2], 2)
self.assertEqual(table["b"][3], 3)
self.lua.globals()["data"] = table
self.assertLuaResult("data.a.a", "foo")
self.assertLuaResult("#data.b", 3)
self.lua.eval("assert(#data.b==3, 'failed')")
self.assertLuaResult("type(data.a)", "table")
self.assertLuaResult("type(data.b)", "table")

def test_table_from_nested_list_bad(self):
data = {"a": {"a": "foo"}, "b": [1, 2, 3]}
table = self.lua.table_from(data) # in this case, lua will get userdata instead of table
self.assertEqual(table["a"]["a"], "foo")
self.assertEqual(list(table["b"]), [1, 2, 3])
self.assertEqual(table["b"][0], 1)
self.assertEqual(table["b"][1], 2)
self.assertEqual(table["b"][2], 3)
self.lua.globals()["data"] = table
self.assertLuaResult("type(data.a)", "userdata")
self.assertLuaResult("type(data.b)", "userdata")

def test_table_from_self_ref_obj(self):
data = {}
data["key"] = data
l = []
l.append(l)
data["list"] = l
table = self.lua.table_from(data, recursive=True)
self.lua.globals()["data"] = table
self.assertLuaResult("type(data)", 'table')
self.assertLuaResult("type(data['key'])",'table')
self.assertLuaResult("type(data['list'])",'table')
self.assertLuaResult("data['list']==data['list'][1]", True)
self.assertLuaResult("type(data['key']['key']['key']['key'])", 'table')
self.assertLuaResult("type(data['key']['key']['key']['key']['list'])", 'table')

# FIXME: it segfaults
# def test_table_from_generator_calling_lua_functions(self):
# func = self.lua.eval("function (obj) return obj end")
Expand Down