1111from decimal import Decimal
1212from functools import wraps
1313from inspect import isclass
14- from typing import Any , Callable , List , Optional , Type , TypeVar , Union , get_type_hints
14+ from typing import (
15+ Any ,
16+ Callable ,
17+ ClassVar ,
18+ Dict ,
19+ List ,
20+ Optional ,
21+ Type ,
22+ TypeVar ,
23+ Union ,
24+ get_type_hints ,
25+ )
1526
1627from cbor2 import CBOREncoder , CBORSimpleValue , CBORTag , dumps , loads , undefined
1728from pprintpp import pformat
@@ -254,7 +265,45 @@ def validate(self):
254265 Raises:
255266 InvalidDataException: When the data is invalid.
256267 """
257- pass
268+ type_hints = get_type_hints (self .__class__ )
269+
270+ def _check_recursive (value , type_hint ):
271+ if type_hint is Any :
272+ return True
273+ origin = getattr (type_hint , "__origin__" , None )
274+ if origin is None :
275+ if isinstance (value , CBORSerializable ):
276+ value .validate ()
277+ return isinstance (value , type_hint )
278+ elif origin is ClassVar :
279+ return _check_recursive (value , type_hint .__args__ [0 ])
280+ elif origin is Union :
281+ return any (_check_recursive (value , arg ) for arg in type_hint .__args__ )
282+ elif origin is Dict or isinstance (value , dict ):
283+ key_type , value_type = type_hint .__args__
284+ return all (
285+ _check_recursive (k , key_type ) and _check_recursive (v , value_type )
286+ for k , v in value .items ()
287+ )
288+ elif origin in (list , set , tuple ):
289+ if value is None :
290+ return True
291+ args = type_hint .__args__
292+ if len (args ) == 1 :
293+ return all (_check_recursive (item , args [0 ]) for item in value )
294+ elif len (args ) > 1 :
295+ return all (
296+ _check_recursive (item , arg ) for item , arg in zip (value , args )
297+ )
298+ return True # We don't know how to check this type
299+
300+ for field_name , field_type in type_hints .items ():
301+ field_value = getattr (self , field_name )
302+ if not _check_recursive (field_value , field_type ):
303+ raise TypeError (
304+ f"Field '{ field_name } ' should be of type { field_type } , "
305+ f"got { repr (field_value )} instead."
306+ )
258307
259308 def to_validated_primitive (self ) -> Primitive :
260309 """Convert the instance and its elements to CBOR primitives recursively with data validated by :meth:`validate`
@@ -505,8 +554,8 @@ class ArrayCBORSerializable(CBORSerializable):
505554 >>> t = Test2(c="c", test1=Test1(a="a"))
506555 >>> t
507556 Test2(c='c', test1=Test1(a='a', b=None))
508- >>> cbor_hex = t.to_cbor()
509- >>> cbor_hex
557+ >>> cbor_hex = t.to_cbor() # doctest: +SKIP
558+ >>> cbor_hex # doctest: +SKIP
510559 '826163826161f6'
511560 >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
512561 Test2(c='c', test1=Test1(a='a', b=None))
@@ -534,8 +583,8 @@ class ArrayCBORSerializable(CBORSerializable):
534583 Test2(c='c', test1=Test1(a='a', b=None))
535584 >>> t.to_primitive() # Notice below that attribute "b" is not included in converted primitive.
536585 ['c', ['a']]
537- >>> cbor_hex = t.to_cbor()
538- >>> cbor_hex
586+ >>> cbor_hex = t.to_cbor() # doctest: +SKIP
587+ >>> cbor_hex # doctest: +SKIP
539588 '826163816161'
540589 >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
541590 Test2(c='c', test1=Test1(a='a', b=None))
@@ -621,8 +670,8 @@ class MapCBORSerializable(CBORSerializable):
621670 Test2(c=None, test1=Test1(a='a', b=''))
622671 >>> t.to_primitive()
623672 {'c': None, 'test1': {'a': 'a', 'b': ''}}
624- >>> cbor_hex = t.to_cbor()
625- >>> cbor_hex
673+ >>> cbor_hex = t.to_cbor() # doctest: +SKIP
674+ >>> cbor_hex # doctest: +SKIP
626675 'a26163f6657465737431a261616161616260'
627676 >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
628677 Test2(c=None, test1=Test1(a='a', b=''))
@@ -645,8 +694,8 @@ class MapCBORSerializable(CBORSerializable):
645694 Test2(c=None, test1=Test1(a='a', b=''))
646695 >>> t.to_primitive()
647696 {'1': {'0': 'a', '1': ''}}
648- >>> cbor_hex = t.to_cbor()
649- >>> cbor_hex
697+ >>> cbor_hex = t.to_cbor() # doctest: +SKIP
698+ >>> cbor_hex # doctest: +SKIP
650699 'a16131a261306161613160'
651700 >>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
652701 Test2(c=None, test1=Test1(a='a', b=''))
0 commit comments