diff --git a/src/apps/monero/controller/misc.py b/src/apps/monero/controller/misc.py index dc58f2060..2b31a340e 100644 --- a/src/apps/monero/controller/misc.py +++ b/src/apps/monero/controller/misc.py @@ -58,36 +58,32 @@ async def monero_get_creds(ctx, address_n=None, network_type=None): return creds -def parse_msg(bts, msg): - from apps.monero.xmr.serialize import xmrserialize +def parse_msg(bts, msg_type): from apps.monero.xmr.serialize.readwriter import MemoryReaderWriter reader = MemoryReaderWriter(memoryview(bts)) - ar = xmrserialize.Archive(reader, False) - return ar.message(msg) + return msg_type.load(reader) -def dump_msg(msg, preallocate=None, msg_type=None, prefix=None): - from apps.monero.xmr.serialize import xmrserialize +def dump_msg(msg, preallocate=None, prefix=None): from apps.monero.xmr.serialize.readwriter import MemoryReaderWriter writer = MemoryReaderWriter(preallocate=preallocate) if prefix: writer.write(prefix) - ar = xmrserialize.Archive(writer, True) - ar.message(msg, msg_type=msg_type) + msg_type = msg.__class__ + msg_type.dump(writer, msg) return writer.get_buffer() -def dump_msg_gc(msg, preallocate=None, msg_type=None, del_msg=False): - b = dump_msg(msg, preallocate=preallocate, msg_type=msg_type) - if del_msg: - del msg +def dump_msg_gc(msg, preallocate=None, prefix=None): + buf = dump_msg(msg, preallocate=preallocate, prefix=None) + del msg import gc gc.collect() - return b + return buf def dump_rsig_bp(rsig): diff --git a/src/apps/monero/protocol/signing/step_06_set_out1.py b/src/apps/monero/protocol/signing/step_06_set_out1.py index 313222b54..a1e2e7e4d 100644 --- a/src/apps/monero/protocol/signing/step_06_set_out1.py +++ b/src/apps/monero/protocol/signing/step_06_set_out1.py @@ -209,7 +209,7 @@ def _range_proof(state, idx, amount, rsig_data=None): _get_out_mask(state, 1 + idx - batch_size + ix) for ix in range(batch_size) ] - bp_obj = misc.parse_msg(rsig_data.rsig, Bulletproof()) + bp_obj = misc.parse_msg(rsig_data.rsig, Bulletproof) rsig_data.rsig = None # BP is hashed with raw=False as hash does not contain L, R diff --git a/src/apps/monero/protocol/signing/step_09_sign_input.py b/src/apps/monero/protocol/signing/step_09_sign_input.py index 24339cec9..56bc6cbe1 100644 --- a/src/apps/monero/protocol/signing/step_09_sign_input.py +++ b/src/apps/monero/protocol/signing/step_09_sign_input.py @@ -173,5 +173,5 @@ async def sign_input( ) return MoneroTransactionSignInputAck( - signature=misc.dump_msg_gc(mgs[0], preallocate=488, del_msg=True), cout=cout + signature=misc.dump_msg_gc(mgs[0], preallocate=488), cout=cout ) diff --git a/src/apps/monero/xmr/serialize/base_types.py b/src/apps/monero/xmr/serialize/base_types.py index 30cc0ec4e..b76c10b72 100644 --- a/src/apps/monero/xmr/serialize/base_types.py +++ b/src/apps/monero/xmr/serialize/base_types.py @@ -1,15 +1,35 @@ +from apps.monero.xmr.serialize.int_serialize import ( + dump_uint, + dump_uvarint, + load_uint, + load_uvarint, +) + + class XmrType: pass class UVarintType(XmrType): - pass + @staticmethod + def load(reader) -> int: + return load_uvarint(reader) + + @staticmethod + def dump(writer, n: int): + return dump_uvarint(writer, n) class IntType(XmrType): WIDTH = 0 - SIGNED = 0 - VARIABLE = 0 + + @classmethod + def load(cls, reader) -> int: + return load_uint(reader, cls.WIDTH) + + @classmethod + def dump(cls, writer, n: int): + return dump_uint(writer, n, cls.WIDTH) class UInt8(IntType): diff --git a/src/apps/monero/xmr/serialize/message_types.py b/src/apps/monero/xmr/serialize/message_types.py index 62261236f..f5cbbd5bf 100644 --- a/src/apps/monero/xmr/serialize/message_types.py +++ b/src/apps/monero/xmr/serialize/message_types.py @@ -1,116 +1,163 @@ +from trezor.utils import obj_eq, obj_repr + from apps.monero.xmr.serialize.base_types import XmrType -from apps.monero.xmr.serialize.obj_helper import eq_obj_contents, is_type, slot_obj_dict +from apps.monero.xmr.serialize.int_serialize import ( + dump_uint, + dump_uvarint, + load_uint, + load_uvarint, +) -class BlobType(XmrType): +class UnicodeType(XmrType): """ - Binary data - - Represented as bytearray() or a list of values in data structures. - Not wrapped in the BlobType, the BlobType is only a scheme descriptor. - Behaves in the same way as primitive types - - Supports also the wrapped version (__init__, DATA_ATTR, eq, repr...), + Unicode data in UTF-8 encoding. """ - DATA_ATTR = "data" - FIX_SIZE = 0 - SIZE = 0 - - def __eq__(self, rhs): - return eq_obj_contents(self, rhs) - - def __repr__(self): - dct = slot_obj_dict(self) if hasattr(self, "__slots__") else self.__dict__ - return "<%s: %s>" % (self.__class__.__name__, dct) + @staticmethod + def dump(writer, s): + dump_uvarint(writer, len(s)) + writer.write(bytes(s)) - -class UnicodeType(XmrType): - pass + @staticmethod + def load(reader): + ivalue = load_uvarint(reader) + fvalue = bytearray(ivalue) + reader.readinto(fvalue) + return str(fvalue) -class VariantType(XmrType): +class BlobType(XmrType): """ - Union of types, variant tags needed. is only one of the types. List in typedef, enum. - Wraps the variant type in order to unambiguously support variant of variants. - Supports also unwrapped value using type system to distinguish variants - simplifies the construction. + Binary data, represented as bytearray. BlobType is only a scheme + descriptor. Behaves in the same way as primitive types. """ - WRAPS_VALUE = False - - def __init__(self): - self.variant_elem = None - self.variant_elem_type = None + FIX_SIZE = 0 + SIZE = 0 @classmethod - def f_specs(cls): - return () + def dump(cls, writer, elem: bytes): + if cls.FIX_SIZE: + if cls.SIZE != len(elem): + raise ValueError("Size mismatch") + else: + dump_uvarint(writer, len(elem)) + writer.write(elem) - def set_variant(self, fname, fvalue): - self.variant_elem = fname - self.variant_elem_type = fvalue.__class__ - setattr(self, fname, fvalue) - - def __eq__(self, rhs): - return eq_obj_contents(self, rhs) - - def __repr__(self): - dct = slot_obj_dict(self) if hasattr(self, "__slots__") else self.__dict__ - return "<%s: %s>" % (self.__class__.__name__, dct) + @classmethod + def load(cls, reader) -> bytearray: + if cls.FIX_SIZE: + size = cls.SIZE + else: + size = load_uvarint(reader) + elem = bytearray(size) + reader.readinto(elem) + return elem class ContainerType(XmrType): """ - Array of elements - Represented as a real array in the data structures, not wrapped in the ContainerType. - The Container type is used only as a schema descriptor for serialization. + Array of elements, represented as a list of items. ContainerType is only a + scheme descriptor. """ FIX_SIZE = 0 SIZE = 0 ELEM_TYPE = None + @classmethod + def dump(cls, writer, elems, elem_type=None): + if elem_type is None: + elem_type = cls.ELEM_TYPE + if cls.FIX_SIZE: + if cls.SIZE != len(elems): + raise ValueError("Size mismatch") + else: + dump_uvarint(writer, len(elems)) + for elem in elems: + elem_type.dump(writer, elem) -class MessageType(XmrType): - def __init__(self, **kwargs): - for kw in kwargs: - setattr(self, kw, kwargs[kw]) + @classmethod + def load(cls, reader, elem_type=None): + if elem_type is None: + elem_type = cls.ELEM_TYPE + if cls.FIX_SIZE: + size = cls.SIZE + else: + size = load_uvarint(reader) + elems = [] + for _ in range(size): + elem = elem_type.load(reader) + elems.append(elem) + return elems + + +class VariantType(XmrType): + """ + Union of types, differentiated by variant tags. VariantType is only a scheme + descriptor. + """ + + @classmethod + def dump(cls, writer, elem): + for field in cls.f_specs(): + ftype = field[1] + if isinstance(elem, ftype): + break + else: + raise ValueError("Unrecognized variant: %s" % elem) - def __eq__(self, rhs): - return eq_obj_contents(self, rhs) + dump_uint(writer, ftype.VARIANT_CODE, 1) + ftype.dump(writer, elem) - def __repr__(self): - dct = slot_obj_dict(self) if hasattr(self, "__slots__") else self.__dict__ - return "<%s: %s>" % (self.__class__.__name__, dct) + @classmethod + def load(cls, reader): + tag = load_uint(reader, 1) + for field in cls.f_specs(): + ftype = field[1] + if ftype.VARIANT_CODE == tag: + fvalue = ftype.load(reader) + break + else: + raise ValueError("Unknown tag: %s" % tag) + return fvalue @classmethod def f_specs(cls): return () -def container_elem_type(container_type, params): +class MessageType(XmrType): """ - Returns container element type + Message composed of fields with specific types. """ - elem_type = params[0] if params else None - if elem_type is None: - elem_type = container_type.ELEM_TYPE - return elem_type + def __init__(self, **kwargs): + for kw in kwargs: + setattr(self, kw, kwargs[kw]) -def gen_elem_array(size, elem_type=None): - """ - Generates element array of given size and initializes with given type. - Supports container type, used for pre-allocation before deserialization. - """ - if elem_type is None or not callable(elem_type): - return [elem_type] * size - if is_type(elem_type, ContainerType): + __eq__ = obj_eq + __repr__ = obj_repr - def elem_type(): - return [] + @classmethod + def dump(cls, writer, msg): + defs = cls.f_specs() + for field in defs: + fname, ftype, *fparams = field + fvalue = getattr(msg, fname, None) + ftype.dump(writer, fvalue, *fparams) - res = [] - for _ in range(size): - res.append(elem_type()) - return res + @classmethod + def load(cls, reader): + msg = cls() + defs = cls.f_specs() + for field in defs: + fname, ftype, *fparams = field + fvalue = ftype.load(reader, *fparams) + setattr(msg, fname, fvalue) + return msg + + @classmethod + def f_specs(cls): + return () diff --git a/src/apps/monero/xmr/serialize/obj_helper.py b/src/apps/monero/xmr/serialize/obj_helper.py deleted file mode 100644 index ea5170758..000000000 --- a/src/apps/monero/xmr/serialize/obj_helper.py +++ /dev/null @@ -1,49 +0,0 @@ -def eq_obj_slots(l, r): - """ - Compares objects with __slots__ defined - """ - for f in l.__slots__: - if getattr(l, f, None) != getattr(r, f, None): - return False - return True - - -def eq_obj_contents(l, r): - """ - Compares object contents, supports slots - """ - if l.__class__ is not r.__class__: - return False - if hasattr(l, "__slots__"): - return eq_obj_slots(l, r) - else: - return l.__dict__ == r.__dict__ - - -def slot_obj_dict(o): - """ - Builds dict for o with __slots__ defined - """ - d = {} - for f in o.__slots__: - d[f] = getattr(o, f, None) - return d - - -def is_type(x, types, full=False): - """ - Returns true if x is of type in types tuple - """ - types = types if isinstance(types, tuple) else (types,) - ins = isinstance(x, types) - sub = False - try: - sub = issubclass(x, types) - except Exception: - pass - res = ins or sub - return res if not full else (res, ins) - - -def get_ftype_params(field): - return field[1], field[2:] diff --git a/src/apps/monero/xmr/serialize/xmrserialize.py b/src/apps/monero/xmr/serialize/xmrserialize.py deleted file mode 100644 index 03a5cfeb1..000000000 --- a/src/apps/monero/xmr/serialize/xmrserialize.py +++ /dev/null @@ -1,386 +0,0 @@ -''' -Minimal streaming codec for a Monero binary serialization. -Used for a binary serialization in blockchain and for hash computation for signatures. - -Equivalent of BEGIN_SERIALIZE_OBJECT(), /src/serialization/serialization.h - -- The wire binary format does not use tags. Structure has to be read from the binary stream -with the scheme specified in order to parse the structure. - -- Heavily uses variable integer serialization - similar to the UTF8 or LZ4 number encoding. - -- Supports: blob, string, integer types - variable or fixed size, containers of elements, - variant types, messages of elements - -For de-serializing (loading) types, object with `AsyncReader` -interface is required: - ->>> class AsyncReader: ->>> async def areadinto(self, buffer): ->>> """ ->>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`. ->>> """ - -For serializing (dumping) types, object with `AsyncWriter` interface is -required: - ->>> class AsyncWriter: ->>> async def awrite(self, buffer): ->>> """ ->>> Writes all bytes from `buffer`, or raises `EOFError`. ->>> """ -''' - -from apps.monero.xmr.serialize.base_types import IntType, UVarintType, XmrType -from apps.monero.xmr.serialize.int_serialize import ( - dump_uint, - dump_uvarint, - load_uint, - load_uvarint, -) -from apps.monero.xmr.serialize.message_types import ( - BlobType, - ContainerType, - MessageType, - UnicodeType, - VariantType, - container_elem_type, -) - - -class Archive: - """ - Archive object for object binary serialization / deserialization. - Resembles Archive API from the Monero codebase or Boost serialization archive. - - The design goal is to provide uniform API both for serialization and deserialization - so the code is not duplicated for serialization and deserialization but the same - for both ways in order to minimize potential bugs in the code. - - In order to use the archive for both ways we have to use so-called field references - as we cannot directly modify given element as a parameter (value-passing) as its performed - in C++ code. see: eref(), get_elem(), set_elem() - """ - - def __init__(self, iobj, writing=True): - self.writing = writing - self.iobj = iobj - - def uvarint(self, elem): - """ - Uvarint - """ - if self.writing: - return dump_uvarint(self.iobj, elem) - else: - return load_uvarint(self.iobj) - - def uint(self, elem, elem_type=None, width=None): - """ - Fixed size int - """ - if self.writing: - return dump_uint(self.iobj, elem, width if width else elem_type.WIDTH) - else: - return load_uint(self.iobj, width if width else elem_type.WIDTH) - - def unicode_type(self, elem): - """ - Unicode type - """ - if self.writing: - return dump_unicode(self.iobj, elem) - else: - return load_unicode(self.iobj) - - def blob(self, elem=None, elem_type=None, params=None): - """ - Loads/dumps blob - """ - elem_type = elem_type if elem_type else elem.__class__ - if self.writing: - return dump_blob(self.iobj, elem, elem_type, params) - else: - return load_blob(self.iobj, elem_type, params, elem) - - def container(self, container=None, container_type=None, params=None): - """ - Loads/dumps container - """ - if self.writing: - return self._dump_container(container, container_type, params) - else: - return self._load_container(container_type, params, container) - - def container_size(self, container_len=None, container_type=None, params=None): - """ - Container size - """ - if self.writing: - return self._dump_container_size(container_len, container_type, params) - else: - raise ValueError("Not supported") - - def variant(self, elem=None, elem_type=None, params=None, wrapped=None): - """ - Loads/dumps variant type - """ - elem_type = elem_type if elem_type else elem.__class__ - if self.writing: - return self._dump_variant( - elem, elem_type if elem_type else elem.__class__, params - ) - else: - return self._load_variant( - elem_type if elem_type else elem.__class__, params, elem, wrapped - ) - - def message(self, msg, msg_type=None): - """ - Loads/dumps message - """ - msg_type = msg_type if msg_type is not None else msg.__class__ - if self.writing: - return self._dump_message(msg, msg_type) - else: - return self._load_message(msg_type, msg) - - def message_field(self, msg, field, fvalue=None): - """ - Dumps/Loads message field - """ - if self.writing: - return self._dump_message_field(msg, field, fvalue) - else: - return self._load_message_field(field) - - def _get_type(self, elem_type): - if issubclass(elem_type, XmrType): - return elem_type - else: - # Can happen due to unimport. - raise ValueError("XMR serialization hierarchy broken") - - def _is_type(self, elem_type, test_type): - return issubclass(elem_type, test_type) - - def field(self, elem=None, elem_type=None, params=None): - elem_type = elem_type if elem_type else elem.__class__ - fvalue = None - - etype = self._get_type(elem_type) - if self._is_type(etype, UVarintType): - fvalue = self.uvarint(elem) - - elif self._is_type(etype, IntType): - fvalue = self.uint(elem, elem_type) - - elif self._is_type(etype, BlobType): - fvalue = self.blob(elem, elem_type, params) - - elif self._is_type(etype, UnicodeType): - fvalue = self.unicode_type(elem) - - elif self._is_type(etype, VariantType): - fvalue = self.variant(elem, elem_type, params) - - elif self._is_type(etype, ContainerType): - fvalue = self.container(elem, elem_type, params) - - elif self._is_type(etype, MessageType): - fvalue = self.message(elem, elem_type) - - else: - raise TypeError( - "unknown type: %s %s %s" % (elem_type, type(elem_type), elem) - ) - - return fvalue - - def dump_field(self, elem, elem_type, params=None): - return self.field(elem, elem_type, params) - - def load_field(self, elem_type, params=None, elem=None): - return self.field(elem, elem_type, params) - - def _dump_container_size(self, container_len, container_type, params=None): - """ - Dumps container size - per element streaming - """ - if not container_type or not container_type.FIX_SIZE: - dump_uvarint(self.iobj, container_len) - elif container_len != container_type.SIZE: - raise ValueError( - "Fixed size container has not defined size: %s" % container_type.SIZE - ) - - def _dump_container(self, container, container_type, params=None): - """ - Dumps container of elements to the writer. - """ - self._dump_container_size(len(container), container_type) - - elem_type = container_elem_type(container_type, params) - - for elem in container: - self.dump_field(elem, elem_type, params[1:] if params else None) - - def _load_container(self, container_type, params=None, container=None): - """ - Loads container of elements from the reader. Supports the container ref. - Returns loaded container. - """ - - c_len = ( - container_type.SIZE if container_type.FIX_SIZE else load_uvarint(self.iobj) - ) - if container and c_len != len(container): - raise ValueError("Size mismatch") - - elem_type = container_elem_type(container_type, params) - res = container if container else [] - for i in range(c_len): - fvalue = self.load_field(elem_type, params[1:] if params else None) - res.append(fvalue) - return res - - def _dump_message_field(self, msg, field, fvalue=None): - """ - Dumps a message field to the writer. Field is defined by the message field specification. - """ - fname, ftype, params = field[0], field[1], field[2:] - fvalue = getattr(msg, fname, None) if fvalue is None else fvalue - self.dump_field(fvalue, ftype, params) - - def _load_message_field(self, field): - """ - Loads message field from the reader. Field is defined by the message field specification. - Returns loaded value, supports field reference. - """ - ftype, params = field[1], field[2:] - return self.load_field(ftype, params) - - def _dump_message(self, msg, msg_type=None): - """ - Dumps message to the writer. - """ - mtype = msg.__class__ if msg_type is None else msg_type - fields = mtype.f_specs() - for field in fields: - self._dump_message_field(msg=msg, field=field) - - def _load_message(self, msg_type, msg=None): - """ - Loads message if the given type from the reader. - Supports reading directly to existing message. - """ - msg = msg_type() if msg is None else msg - fields = msg_type.f_specs() if msg_type else msg.__class__.f_specs() - for field in fields: - fval = self._load_message_field(field) - setattr(msg, field[0], fval) - - return msg - - def _dump_variant(self, elem, elem_type=None, params=None): - """ - Dumps variant type to the writer. - Supports both wrapped and raw variant. - """ - if isinstance(elem, VariantType) or elem_type.WRAPS_VALUE: - dump_uint(self.iobj, elem.variant_elem_type.VARIANT_CODE, 1) - self.dump_field(getattr(elem, elem.variant_elem), elem.variant_elem_type) - - else: - fdef = find_variant_fdef(elem_type, elem) - dump_uint(self.iobj, fdef[1].VARIANT_CODE, 1) - self.dump_field(elem, fdef[1]) - - def _load_variant(self, elem_type, params=None, elem=None, wrapped=None): - """ - Loads variant type from the reader. - Supports both wrapped and raw variant. - """ - is_wrapped = ( - (isinstance(elem, VariantType) or elem_type.WRAPS_VALUE) - if wrapped is None - else wrapped - ) - if is_wrapped: - elem = elem_type() if elem is None else elem - - tag = load_uint(self.iobj, 1) - for field in elem_type.f_specs(): - ftype = field[1] - if ftype.VARIANT_CODE == tag: - fvalue = self.load_field( - ftype, field[2:], elem if not is_wrapped else None - ) - if is_wrapped: - elem.set_variant(field[0], fvalue) - return elem if is_wrapped else fvalue - raise ValueError("Unknown tag: %s" % tag) - - -def dump_blob(writer, elem, elem_type, params=None): - """ - Dumps blob message to the writer. - Supports both blob and raw value. - """ - elem_is_blob = isinstance(elem, BlobType) - elem_params = elem if elem_is_blob or elem_type is None else elem_type - data = bytes(getattr(elem, BlobType.DATA_ATTR) if elem_is_blob else elem) - - if not elem_params.FIX_SIZE: - dump_uvarint(writer, len(elem)) - elif len(data) != elem_params.SIZE: - raise ValueError("Fixed size blob has not defined size: %s" % elem_params.SIZE) - writer.write(data) - - -def load_blob(reader, elem_type, params=None, elem=None): - """ - Loads blob from reader to the element. Returns the loaded blob. - """ - ivalue = elem_type.SIZE if elem_type.FIX_SIZE else load_uvarint(reader) - fvalue = bytearray(ivalue) - reader.readinto(fvalue) - - if elem is None: - return fvalue # array by default - - elif isinstance(elem, BlobType): - setattr(elem, elem_type.DATA_ATTR, fvalue) - return elem - - else: - elem.extend(fvalue) - - return elem - - -def dump_unicode(writer, elem): - dump_uvarint(writer, len(elem)) - writer.write(bytes(elem, "utf8")) - - -def load_unicode(reader): - ivalue = load_uvarint(reader) - fvalue = bytearray(ivalue) - reader.readinto(fvalue) - return str(fvalue, "utf8") - - -def find_variant_fdef(elem_type, elem): - fields = elem_type.f_specs() - for x in fields: - if isinstance(elem, x[1]): - return x - - # Not direct hierarchy - name = elem.__class__.__name__ - for x in fields: - if name == x[1].__name__: - return x - - raise ValueError("Unrecognized variant: %s" % elem) diff --git a/src/trezor/utils.py b/src/trezor/utils.py index 0b505680a..79b71f2f4 100644 --- a/src/trezor/utils.py +++ b/src/trezor/utils.py @@ -75,3 +75,36 @@ def append(self, b: int): def get_digest(self) -> bytes: return self.ctx.digest() + + +def obj_eq(l, r): + """ + Compares object contents, supports __slots__. + """ + if l.__class__ is not r.__class__: + return False + if hasattr(l, "__slots__"): + return obj_slots_dict(l) == obj_slots_dict(r) + else: + return l.__dict__ == r.__dict__ + + +def obj_repr(o): + """ + Returns a string representation of object, supports __slots__. + """ + if hasattr(o, "__slots__"): + d = obj_slots_dict(o) + else: + d = o.__dict__ + return "<%s: %s>" % (o.__class__.__name__, d) + + +def obj_slots_dict(o): + """ + Builds dict for o from defined __slots__. + """ + d = {} + for f in o.__slots__: + d[f] = getattr(o, f, None) + return d diff --git a/tests/test_apps.monero.serializer.py b/tests/test_apps.monero.serializer.py index 95c278d5b..2dc0f2ea1 100644 --- a/tests/test_apps.monero.serializer.py +++ b/tests/test_apps.monero.serializer.py @@ -1,13 +1,18 @@ -from common import * import utest - +from common import * from trezor import log, loop, utils -from apps.monero.xmr.serialize import xmrserialize as xms + +from apps.monero.xmr.serialize.int_serialize import ( + dump_uint, + dump_uvarint, + load_uint, + load_uvarint, +) from apps.monero.xmr.serialize.readwriter import MemoryReaderWriter from apps.monero.xmr.serialize_messages.base import ECPoint from apps.monero.xmr.serialize_messages.tx_prefix import ( - TxinToKey, TxinGen, + TxinToKey, TxInV, TxOut, TxoutToKey, @@ -84,8 +89,8 @@ def test_varint(self): for test_num in test_nums: writer = MemoryReaderWriter() - xms.dump_uvarint(writer, test_num) - test_deser = xms.load_uvarint(MemoryReaderWriter(writer.get_buffer())) + dump_uvarint(writer, test_num) + test_deser = load_uvarint(MemoryReaderWriter(writer.get_buffer())) self.assertEqual(test_num, test_deser) @@ -97,35 +102,12 @@ def test_ecpoint(self): ec_data = bytearray(range(32)) writer = MemoryReaderWriter() - xms.dump_blob(writer, ec_data, ECPoint) + ECPoint.dump(writer, ec_data) self.assertTrue(len(writer.get_buffer()), ECPoint.SIZE) - test_deser = xms.load_blob( - MemoryReaderWriter(writer.get_buffer()), ECPoint - ) + test_deser = ECPoint.load(MemoryReaderWriter(writer.get_buffer())) self.assertEqual(ec_data, test_deser) - def test_ecpoint_obj(self): - """ - EC point into - :return: - """ - ec_data = bytearray(list(range(32))) - ec_point = ECPoint() - ec_point.data = ec_data - writer = MemoryReaderWriter() - - xms.dump_blob(writer, ec_point, ECPoint) - self.assertTrue(len(writer.get_buffer()), ECPoint.SIZE) - - ec_point2 = ECPoint() - test_deser = xms.load_blob( - MemoryReaderWriter(writer.get_buffer()), ECPoint, elem=ec_point2 - ) - - self.assertEqual(ec_data, ec_point2.data) - self.assertEqual(ec_point, ec_point2) - def test_simple_msg(self): """ TxinGen @@ -134,30 +116,10 @@ def test_simple_msg(self): msg = TxinGen(height=42) writer = MemoryReaderWriter() - ar1 = xms.Archive(writer, True) - ar1.message(msg) - - ar2 = xms.Archive(MemoryReaderWriter(writer.get_buffer()), False) - test_deser = ar2.message(None, msg_type=TxinGen) - self.assertEqual(msg.height, test_deser.height) - - def test_simple_msg_into(self): - """ - TxinGen - :return: - """ - msg = TxinGen(height=42) - - writer = MemoryReaderWriter() - ar1 = xms.Archive(writer, True) - ar1.message(msg) + TxinGen.dump(writer, msg) + test_deser = TxinGen.load(MemoryReaderWriter(writer.get_buffer())) - msg2 = TxinGen() - ar2 = xms.Archive(MemoryReaderWriter(writer.get_buffer()), False) - test_deser = ar2.message(msg2, TxinGen) self.assertEqual(msg.height, test_deser.height) - self.assertEqual(msg.height, msg2.height) - self.assertEqual(msg2, test_deser) def test_txin_to_key(self): """ @@ -169,11 +131,9 @@ def test_txin_to_key(self): ) writer = MemoryReaderWriter() - ar1 = xms.Archive(writer, True) - ar1.message(msg) + TxinToKey.dump(writer, msg) + test_deser = TxinToKey.load(MemoryReaderWriter(writer.get_buffer())) - ar2 = xms.Archive(MemoryReaderWriter(writer.get_buffer()), False) - test_deser = ar2.message(None, TxinToKey) self.assertEqual(msg.amount, test_deser.amount) self.assertEqual(msg, test_deser) @@ -185,19 +145,13 @@ def test_txin_variant(self): msg1 = TxinToKey( amount=123, key_offsets=[1, 2, 3, 2 ** 76], k_image=bytearray(range(32)) ) - msg = TxInV() - msg.set_variant("txin_to_key", msg1) writer = MemoryReaderWriter() - ar1 = xms.Archive(writer, True) - ar1.variant(msg) + TxInV.dump(writer, msg1) + test_deser = TxInV.load(MemoryReaderWriter(writer.get_buffer())) - ar2 = xms.Archive(MemoryReaderWriter(writer.get_buffer()), False) - test_deser = ar2.variant(None, TxInV, wrapped=True) - self.assertEqual(test_deser.__class__, TxInV) - self.assertEqual(msg, test_deser) - self.assertEqual(msg.variant_elem, test_deser.variant_elem) - self.assertEqual(msg.variant_elem_type, test_deser.variant_elem_type) + self.assertEqual(test_deser.__class__, TxinToKey) + self.assertEqual(msg1, test_deser) if __name__ == "__main__":