Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
Merge pull request #22 from noodle4u/load_module, close #22
Browse files Browse the repository at this point in the history
more standard thrift sdk module generator
  • Loading branch information
lxyu committed Aug 5, 2014
2 parents 90b0bd9 + c20b0fc commit 8071a56
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 29 deletions.
5 changes: 3 additions & 2 deletions thriftpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import sys
__python__ = sys.version

__all__ = ["install_import_hook", "remove_import_hook", "load"]
__all__ = ["install_import_hook", "remove_import_hook", "load", "load_file",
"load_module"]

from .hook import install_import_hook, remove_import_hook
from .parser import load
from .parser import load, load_file, load_module
21 changes: 3 additions & 18 deletions thriftpy/hook.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-

import os
import sys

from .parser import load
from .parser import load_module


class ThriftImporter(object):
Expand All @@ -20,24 +19,10 @@ def find_module(self, fullname, path=None):
return self

def load_module(self, fullname):
if '.' in fullname:
module_name, thrift_file = fullname.rsplit('.', 1)
module = self._import_module(module_name)
path_prefix = os.path.dirname(os.path.abspath(module.__file__))
path = os.path.join(path_prefix, thrift_file)
else:
path = fullname
filename = path.replace('_thrift', '.thrift', 1)
thrift = load(filename)
sys.modules[fullname] = thrift
thrift = load_module(fullname)
return thrift

def _import_module(self, import_name):
if '.' in import_name:
module, obj = import_name.rsplit('.', 1)
return getattr(__import__(module, None, None, [obj]), obj)
else:
return __import__(import_name)

_imp = ThriftImporter()


Expand Down
70 changes: 61 additions & 9 deletions thriftpy/parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
# flake8: noqa

import functools
import hashlib
import itertools
import os.path
import os
import pickle
import sys
import types

from .thrift import TType, TPayload, TException
Expand Down Expand Up @@ -105,19 +105,29 @@ def parse(schema):

return result

from collections import OrderedDict

def load(thrift_file, cache=True):
def load_file(thrift_file, cache=True, module_name=None):
"""Load thrift_file as a module, default use cache to accelerate
tokenize processing.
Set cache to False if you don't want to load from cache.
the result is a none standard python module,we can't pickle it,
use load_module to get a pickleble thrift_schema
"""
global _thriftloader
if thrift_file in _thriftloader:
return _thriftloader[thrift_file]
_thriftloader_key = {
'thrift_file': thrift_file,
'module_name': module_name,
}
_thriftloader_key = str(OrderedDict(_thriftloader_key))
if _thriftloader_key in _thriftloader:
return _thriftloader[_thriftloader_key]

basename = os.path.basename(thrift_file)
module_name, _ = os.path.splitext(basename)
if not module_name:
basename = os.path.basename(thrift_file)
module_name, _ = os.path.splitext(basename)

with open(thrift_file, "r") as fp:
schema = fp.read()
Expand Down Expand Up @@ -275,6 +285,48 @@ def _ttype_spec(ttype, name):

setattr(service_cls, "thrift_services", thrift_services)
setattr(thrift_schema, service.name, service_cls)
thrift_schema.__file__ = thrift_file

_thriftloader[_thriftloader_key] = thrift_schema
return _thriftloader[_thriftloader_key]


def _import_module(import_name):
if '.' in import_name:
module, obj = import_name.rsplit('.', 1)
return getattr(__import__(module, None, None, [obj]), obj)
else:
return __import__(import_name)


def _gen_path_from_module_name(module_name):
if '.' in module_name:
module_name, thrift_file = module_name.rsplit('.', 1)
module = _import_module(module_name)
path_prefix = os.path.dirname(os.path.abspath(module.__file__))
path = os.path.join(path_prefix, thrift_file)
else:
path = module_name
_path = list(path)
_path[-7] = '.'
filename = ''.join(_path)
# filename = path.replace('_thrift', '.thrift', 1)
return filename


def load_module(module_name, cache=True):
"""
:param module_name:
:param cache:
:return:
thrift_file must be a subpath of any path in sys.path
load thrift_file as a standard python module.then we can pickle.dumps
thrift content
"""
thrift_file = _gen_path_from_module_name(module_name)
module = load_file(thrift_file, cache=cache, module_name=module_name)
sys.modules[module_name] = module
return module

_thriftloader[thrift_file] = thrift_schema
return _thriftloader[thrift_file]
# backwards compatible
load = load_file

0 comments on commit 8071a56

Please sign in to comment.