Skip to content

Commit

Permalink
Simplify loading lua classes: add `PyTorchHelpers.load_lua_class(lua_…
Browse files Browse the repository at this point in the history
…filename, lua_classname)`
  • Loading branch information
hughperkins committed Mar 4, 2016
1 parent dd5dd61 commit 96b1bae
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_file_datetime(filepath):

setup(
name='PyTorch',
version='2.5.0',
version='2.6.0',
author='Hugh Perkins',
author_email='[email protected]',
description=(
Expand All @@ -152,7 +152,7 @@ def get_file_datetime(filepath):
install_requires=['numpy'],
scripts=[],
ext_modules=ext_modules,
py_modules=['floattensor', 'PyTorchAug'],
py_modules=['floattensor', 'PyTorchAug', 'PyTorchHelpers'],
package_dir={'': 'src'}
)

12 changes: 2 additions & 10 deletions simpleexample/pybit.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import sys
import os
import PyTorchAug
import PyTorch
import PyTorchHelpers
import numpy as np

PyTorch.require('luabit')
class Luabit(PyTorchAug.LuaClass):
def __init__(self, _fromLua=False):
self.luaclass = 'Luabit'
if not _fromLua:
name = self.__class__.__name__
super(self.__class__, self).__init__([name])
else:
self.__dict__['__objectId'] = getNextObjectId()
Luabit = PyTorchHelpers.load_lua_class('luabit.lua', 'Luabit')

batchSize = 2
numFrames = 4
Expand Down
16 changes: 16 additions & 0 deletions src/PyTorchHelpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import PyTorch
import PyTorchAug

def load_lua_class(lua_filename, lua_classname):
module = lua_filename.replace('.lua', '')
PyTorch.require(module)
class LuaWrapper(PyTorchAug.LuaClass):
def __init__(self, _fromLua=False):
self.luaclass = lua_classname
if not _fromLua:
name = lua_classname
super(self.__class__, self).__init__([name])
else:
self.__dict__['__objectId'] = getNextObjectId()
return LuaWrapper

0 comments on commit 96b1bae

Please sign in to comment.