88from dataclasses import Field , dataclass , fields
99from datetime import datetime
1010from decimal import Decimal
11+ from functools import wraps
1112from inspect import isclass
1213from typing import Any , Callable , List , Type , TypeVar , Union , get_type_hints
1314
@@ -105,6 +106,31 @@ class RawCBOR:
105106`Cbor2 encoder <https://cbor2.readthedocs.io/en/latest/modules/encoder.html>`_ directly.
106107"""
107108
109+
110+ def limit_primitive_type (* allowed_types ):
111+ """
112+ A helper function to validate primitive type given to from_primitive class methods
113+
114+ Not exposed to public by intention.
115+ """
116+
117+ def decorator (func ):
118+ @wraps (func )
119+ def wrapper (cls , value : Primitive ):
120+ if not isinstance (value , allowed_types ):
121+ allowed_types_str = [
122+ allowed_type .__name__ for allowed_type in allowed_types
123+ ]
124+ raise DeserializeException (
125+ f"{ allowed_types_str } typed value is required for deserialization. Got { type (value )} : { value } "
126+ )
127+ return func (cls , value )
128+
129+ return wrapper
130+
131+ return decorator
132+
133+
108134CBORBase = TypeVar ("CBORBase" , bound = "CBORSerializable" )
109135
110136
@@ -245,7 +271,7 @@ def to_validated_primitive(self) -> Primitive:
245271 return self .to_primitive ()
246272
247273 @classmethod
248- def from_primitive (cls : Type [CBORBase ], value : Primitive ) -> CBORBase :
274+ def from_primitive (cls : Type [CBORBase ], value : Any ) -> CBORBase :
249275 """Turn a CBOR primitive to its original class type.
250276
251277 Args:
@@ -407,7 +433,7 @@ def _restore_dataclass_field(
407433 elif t in PRIMITIVE_TYPES and isinstance (v , t ):
408434 return v
409435 raise DeserializeException (
410- f"Cannot deserialize object: \n { str ( v ) } \n in any valid type from { t_args } ."
436+ f"Cannot deserialize object: \n { v } \n in any valid type from { t_args } ."
411437 )
412438 return v
413439
@@ -494,7 +520,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
494520 return primitives
495521
496522 @classmethod
497- def from_primitive (cls : Type [ArrayBase ], values : Primitive ) -> ArrayBase :
523+ @limit_primitive_type (list )
524+ def from_primitive (cls : Type [ArrayBase ], values : list ) -> ArrayBase :
498525 """Restore a primitive value to its original class type.
499526
500527 Args:
@@ -508,10 +535,6 @@ def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
508535 DeserializeException: When the object could not be restored from primitives.
509536 """
510537 all_fields = [f for f in fields (cls ) if f .init ]
511- if type (values ) != list :
512- raise DeserializeException (
513- f"Expect input value to be a list, got a { type (values )} instead."
514- )
515538
516539 restored_vals = []
517540 type_hints = get_type_hints (cls )
@@ -606,7 +629,8 @@ def to_shallow_primitive(self) -> Primitive:
606629 return primitives
607630
608631 @classmethod
609- def from_primitive (cls : Type [MapBase ], values : Primitive ) -> MapBase :
632+ @limit_primitive_type (dict )
633+ def from_primitive (cls : Type [MapBase ], values : dict ) -> MapBase :
610634 """Restore a primitive value to its original class type.
611635
612636 Args:
@@ -620,10 +644,6 @@ def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
620644 :class:`pycardano.exception.DeserializeException`: When the object could not be restored from primitives.
621645 """
622646 all_fields = {f .metadata .get ("key" , f .name ): f for f in fields (cls ) if f .init }
623- if type (values ) != dict :
624- raise DeserializeException (
625- f"Expect input value to be a dict, got a { type (values )} instead."
626- )
627647
628648 kwargs = {}
629649 type_hints = get_type_hints (cls )
@@ -725,7 +745,8 @@ def _get_sortable_val(key):
725745 return dict (sorted (self .data .items (), key = lambda x : _get_sortable_val (x [0 ])))
726746
727747 @classmethod
728- def from_primitive (cls : Type [DictBase ], value : Primitive ) -> DictBase :
748+ @limit_primitive_type (dict )
749+ def from_primitive (cls : Type [DictBase ], value : dict ) -> DictBase :
729750 """Restore a primitive value to its original class type.
730751
731752 Args:
@@ -739,11 +760,7 @@ def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
739760 DeserializeException: When the object could not be restored from primitives.
740761 """
741762 if not value :
742- raise DeserializeException (f"Cannot accept empty value { str (value )} ." )
743- if not isinstance (value , dict ):
744- raise DeserializeException (
745- f"A dictionary value is required for deserialization: { str (value )} "
746- )
763+ raise DeserializeException (f"Cannot accept empty value { value } ." )
747764
748765 restored = cls ()
749766 for k , v in value .items ():
0 commit comments