Skip to content

Commit

Permalink
Can pass tables to lua functions now ,from python
Browse files Browse the repository at this point in the history
  • Loading branch information
hughperkins committed Mar 4, 2016
1 parent 7522c32 commit c1a8b0c
Show file tree
Hide file tree
Showing 12 changed files with 284 additions and 224 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ from __future__ import print_function, division

# Recent news

5 March:
* added `PyTorchHelpers.load_lua_class(lua_filename, lua_classname)` to easily import a lua class from a lua file
* can pass parameters to lua class constructors, from python
* can pass tables to lua functions, from python (pass in as python dictionaries, become lua tables)

2 March:
* removed requirements on Cython, Jinja2 for installation

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_file_datetime(filepath):

setup(
name='PyTorch',
version='2.7.0',
version='2.8.0',
author='Hugh Perkins',
author_email='[email protected]',
description=(
Expand Down
6 changes: 6 additions & 0 deletions simpleexample/luabit.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ function Luabit:getOut(inTensor, outSize, kernelSize)
return out
end

function Luabit:printTable(sometable)
for k, v in pairs(sometable) do
print('Luabit:printTable ', k, v)
end
end

2 changes: 2 additions & 0 deletions simpleexample/pybit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
outTensor = luaout.asNumpyTensor()
print('outTensor', outTensor)

luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234})

73 changes: 21 additions & 52 deletions src/PyTorch.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 37 additions & 29 deletions src/PyTorchAug.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,50 @@ def torchType(lua, pos):
lua.call(1, 1)
return popString(lua)

def pushSomething(lua, something):
if isinstance(something, int):
lua.pushNumber(something)
return

if isinstance(something, float):
lua.pushNumber(something)
return

if isinstance(something, str):
lua.pushString(something)
return

if isinstance(something, dict):
pushTable(lua, something)
return

for pythonClass in pushFunctionByPythonClass:
if isinstance(something, pythonClass):
pushFunctionByPythonClass[pythonClass](something)
return

if type(something) in luaClassesReverse:
pushObject(lua, something)
return

raise Exception('pushing type ' + str(type(something)) + ' not implemented, value ', something)

def pushTable(lua, table):
lua.newTable()
for k, v in table.items():
pushSomething(lua, k)
pushSomething(lua, v)
lua.setTable(3)

class LuaClass(object):
def __init__(self, *args, nameList):
# print('LuaClass.__init__()')
lua = PyTorch.getGlobalState().getLua()
# self.luaclass = luaclass
self.__dict__['__objectId'] = getNextObjectId()
topStart = lua.getTop()
pushGlobalFromList(lua, nameList)
for arg in args:
if isinstance(arg, int):
lua.pushNumber(arg)
elif isinstance(arg, str):
lua.pushString(arg)
else:
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
pushSomething(lua, arg)
lua.call(len(args), 1)
registerObject(lua, self)

Expand Down Expand Up @@ -160,29 +189,8 @@ def mymethod(*args):
pushObject(lua, self)
lua.getField(-1, name)
lua.insert(-2)
# pushObject(lua, self)
for arg in args:
# print('arg', arg, type(arg))
pushedArg = False
for pythonClass in pushFunctionByPythonClass:
if isinstance(arg, pythonClass):
pushFunctionByPythonClass[pythonClass](arg)
pushedArg = True
break
if not pushedArg and type(arg) in luaClassesReverse:
pushObject(lua, arg)
pushedArg = True
if not pushedArg and isinstance(arg, float):
lua.pushNumber(arg)
pushedArg = True
if not pushedArg and isinstance(arg, int):
lua.pushNumber(arg)
pushedArg = True
if not pushedArg and isinstance(arg, str):
lua.pushString(arg)
pushedArg = True
if not pushedArg:
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
pushSomething(lua, arg)
lua.call(len(args) + 1, 1) # +1 for self
lua.pushValue(-1)
pushGlobal(lua, 'torch', 'type')
Expand Down
34 changes: 2 additions & 32 deletions src/Storage.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c1a8b0c

Please sign in to comment.