diff --git a/tests/test_base.py b/tests/test_base.py index c67dd6e..474b9fa 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -import linecache - import pytest import thriftpy @@ -70,8 +68,3 @@ def test_parse_spec(): for spec, res in cases: assert parse_spec(*spec) == res - - -def test_init_func(): - thriftpy.load("addressbook.thrift") - assert linecache.getline('', 1) != '' diff --git a/tests/test_type.py b/tests/test_type.py index b4e27ff..8df4559 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -1,10 +1,61 @@ # -*- coding: utf-8 -*- from thriftpy import load -from thriftpy.thrift import TType +from thriftpy.thrift import TType, TPayload def test_set(): s = load("type.thrift") assert s.Set.thrift_spec == {1: (TType.SET, "a_set", TType.STRING, True)} + + +class Struct(TPayload): + thrift_spec = { + 1: (TType.MAP, 'tdict', (TType.I32, TType.I32), False), + 2: (TType.SET, 'tset', TType.I32, False), + 3: (TType.LIST, 'tlist', TType.I32, False), + } + + default_spec = [ + ('tdict', {}), + ('tset', set()), + ('tlist', []), + ] + + +# making an object with default values and then mutating the object should not +# change the default values + +def test_mutable_default_dict(): + s1 = Struct() + s1.tdict[1] = 2 + + s2 = Struct() + assert s2.tdict == {} + + +def test_mutable_default_list(): + s1 = Struct() + s1.tlist.append(1) + + s2 = Struct() + assert s2.tlist == [] + + +def test_mutable_default_set(): + s1 = Struct() + s1.tset.add(1) + + s2 = Struct() + assert s2.tset == set() + + +def test_positional_args(): + # thriftpy instantiates TPayload objects using positional args. + # thriftpy.thrift.TException being the most notable example. + # make sure that we don't break backwards compatiblity + s1 = Struct({1: 2}, set([3, 4]), [5, 6]) + assert s1.tdict == {1: 2} + assert s1.tset == set([3, 4]) + assert s1.tlist == [5, 6] diff --git a/thriftpy/thrift.py b/thriftpy/thrift.py index bf1db20..81efc2d 100644 --- a/thriftpy/thrift.py +++ b/thriftpy/thrift.py @@ -9,10 +9,8 @@ from __future__ import absolute_import +import copy import functools -import linecache -import types - from ._compat import with_metaclass @@ -54,25 +52,28 @@ def __init__(self, name='Alice', number=None): self.number = number """ if not spec: - def __init__(self): - pass - return __init__ - - varnames, defaults = zip(*spec) - - args = ', '.join(map('{0[0]}={0[1]!r}'.format, spec)) - init = "def __init__(self, {0}):\n".format(args) - init += "\n".join(map(' self.{0} = {0}'.format, varnames)) - - name = ''.format(cls.__name__) - code = compile(init, name, 'exec') - func = next(c for c in code.co_consts if isinstance(c, types.CodeType)) - - # Add a fake linecache entry so debuggers and the traceback module can - # better understand our generated code. - linecache.cache[name] = (len(init), None, init.splitlines(True), name) - - return types.FunctionType(func, {}, argdefs=defaults) + spec = [] + + def __init__(self, *args, **kwargs): + # __init__ might get passed args or kwargs assume that positional args + # are in the same order specified by spec anything else is a kwarg + i = len(args) + arg_spec = spec[:i] + kw_spec = spec[i:] + + for arg, (name, _) in zip(args, arg_spec): + setattr(self, name, arg) + + for name, default in kw_spec: + if name in kwargs: + setattr(self, name, kwargs.pop(name)) + else: + # make a copy of the default values so that we can mutate them + # without affecting anything else + setattr(self, name, copy.copy(default)) + assert not kwargs + + return __init__ class TType(object):