diff --git a/README.md b/README.md index 12fcee2..8b6acac 100644 --- a/README.md +++ b/README.md @@ -443,18 +443,15 @@ True ``` -### Recursive schema +### Recursive / nested schema -There is no syntax to have a recursive schema. The best way to do it is to have a wrapper like this: +You can use `voluptuous.Self` to define a nested schema: ```pycon ->>> from voluptuous import Schema, Any ->>> def s2(v): -... return s1(v) -... ->>> s1 = Schema({"key": Any(s2, "value")}) ->>> s1({"key": {"key": "value"}}) -{'key': {'key': 'value'}} +>>> from voluptuous import Schema, Self +>>> recursive = Schema({"more": Self, "value": int}) +>>> recursive({"more": {"value": 42}, "value": 41}) == {'more': {'value': 42}, 'value': 41} +True ``` diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index 724c504..dff3e55 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -122,6 +122,10 @@ def __repr__(self): UNDEFINED = Undefined() +def Self(): + raise er.SchemaError('"Self" should never be called') + + def default_factory(value): if value is UNDEFINED or callable(value): return value @@ -270,6 +274,10 @@ def __call__(self, data): def _compile(self, schema): if schema is Extra: return lambda _, v: v + if schema is Self: + return lambda p, v: self._compiled(p, v) + elif hasattr(schema, "__voluptuous_compile__"): + return schema.__voluptuous_compile__(self) if isinstance(schema, Object): return self._compile_object(schema) if isinstance(schema, collections.Mapping): diff --git a/voluptuous/tests/tests.py b/voluptuous/tests/tests.py index ee5f2fd..98a82ca 100644 --- a/voluptuous/tests/tests.py +++ b/voluptuous/tests/tests.py @@ -10,7 +10,8 @@ Url, MultipleInvalid, LiteralInvalid, TypeInvalid, NotIn, Match, Email, Replace, Range, Coerce, All, Any, Length, FqdnUrl, ALLOW_EXTRA, PREVENT_EXTRA, validate, ExactSequence, Equal, Unordered, Number, Maybe, Datetime, Date, - Contains, Marker, IsDir, IsFile, PathExists, SomeOf, TooManyValid, raises) + Contains, Marker, IsDir, IsFile, PathExists, SomeOf, TooManyValid, Self, + raises) from voluptuous.humanize import humanize_error from voluptuous.util import u @@ -1065,6 +1066,74 @@ def test_SomeOf_max_validation(): validator('Aa1') +def test_self_validation(): + schema = Schema({"number": int, + "follow": Self}) + try: + schema({"number": "abc"}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + try: + schema({"follow": {"number": '123456.712'}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + schema({"follow": {"number": 123456}}) + schema({"follow": {"follow": {"number": 123456}}}) + + +def test_self_any(): + schema = Schema({"number": int, + "follow": Any(Self, "stop")}) + try: + schema({"number": "abc"}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + try: + schema({"follow": {"number": '123456.712'}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + schema({"follow": {"number": 123456}}) + schema({"follow": {"follow": {"number": 123456}}}) + schema({"follow": {"follow": {"number": 123456, "follow": "stop"}}}) + + +def test_self_all(): + schema = Schema({"number": int, + "follow": All(Self, + Schema({"extra_number": int}, + extra=ALLOW_EXTRA))}, + extra=ALLOW_EXTRA) + try: + schema({"number": "abc"}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + try: + schema({"follow": {"number": '123456.712'}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + schema({"follow": {"number": 123456}}) + schema({"follow": {"follow": {"number": 123456}}}) + schema({"follow": {"number": 123456, "extra_number": 123}}) + try: + schema({"follow": {"number": 123456, "extra_number": "123"}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + + def test_SomeOf_on_bounds_assertion(): with raises(AssertionError, 'when using "SomeOf" you should specify at least one of min_valid and max_valid'): SomeOf(validators=[]) diff --git a/voluptuous/validators.py b/voluptuous/validators.py index 138941a..af655c3 100644 --- a/voluptuous/validators.py +++ b/voluptuous/validators.py @@ -181,7 +181,40 @@ def Boolean(v): return bool(v) -class Any(object): +class _WithSubValidators(object): + """Base class for validators that use sub-validators. + + Special class to use as a parent class for validators using sub-validators. + This class provides the `__voluptuous_compile__` method so the + sub-validators are compiled by the parent `Schema`. + """ + + def __init__(self, *validators, **kwargs): + self.validators = validators + self.msg = kwargs.pop('msg', None) + + def __voluptuous_compile__(self, schema): + self._compiled = [ + schema._compile(v) + for v in self.validators + ] + return self._run + + def _run(self, path, value): + return self._exec(self._compiled, value, path) + + def __call__(self, v): + return self._exec((Schema(val) for val in self.validators), v) + + def __repr__(self): + return '%s(%s, msg=%r)' % ( + self.__class__.__name__, + ", ".join(repr(v) for v in self.validators), + self.msg + ) + + +class Any(_WithSubValidators): """Use the first validated value. :param msg: Message to deliver to user if validation fails. @@ -206,16 +239,14 @@ class Any(object): ... validate(4) """ - def __init__(self, *validators, **kwargs): - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] - - def __call__(self, v): + def _exec(self, funcs, v, path=None): error = None - for schema in self._schemas: + for func in funcs: try: - return schema(v) + if path is None: + return func(v) + else: + return func(path, v) except Invalid as e: if error is None or len(e.path) > len(error.path): error = e @@ -224,15 +255,12 @@ def __call__(self, v): raise error if self.msg is None else AnyInvalid(self.msg) raise AnyInvalid(self.msg or 'no valid value found') - def __repr__(self): - return 'Any([%s])' % (", ".join(repr(v) for v in self.validators)) - # Convenience alias Or = Any -class All(object): +class All(_WithSubValidators): """Value must pass all validators. The output of each validator is passed as input to the next. @@ -245,25 +273,17 @@ class All(object): 10 """ - def __init__(self, *validators, **kwargs): - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] - - def __call__(self, v): + def _exec(self, funcs, v, path=None): try: - for schema in self._schemas: - v = schema(v) + for func in funcs: + if path is None: + v = func(v) + else: + v = func(path, v) except Invalid as e: raise e if self.msg is None else AllInvalid(self.msg) return v - def __repr__(self): - return 'All(%s, msg=%r)' % ( - ", ".join(repr(v) for v in self.validators), - self.msg - ) - # Convenience alias And = All @@ -936,7 +956,7 @@ def _get_precision_scale(self, number): return (len(decimal_num.as_tuple().digits), -(decimal_num.as_tuple().exponent), decimal_num) -class SomeOf(object): +class SomeOf(_WithSubValidators): """Value must pass at least some validations, determined by the given parameter. Optionally, number of passed validations can be capped. @@ -965,19 +985,21 @@ def __init__(self, validators, min_valid=None, max_valid=None, **kwargs): 'when using "%s" you should specify at least one of min_valid and max_valid' % (type(self).__name__,) self.min_valid = min_valid or 0 self.max_valid = max_valid or len(validators) - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] + super(SomeOf, self).__init__(*validators, **kwargs) - def __call__(self, v): + def _exec(self, funcs, v, path=None): errors = [] - for schema in self._schemas: + funcs = list(funcs) + for func in funcs: try: - v = schema(v) + if path is None: + v = func(v) + else: + v = func(path, v) except Invalid as e: errors.append(e) - passed_count = len(self._schemas) - len(errors) + passed_count = len(funcs) - len(errors) if self.min_valid <= passed_count <= self.max_valid: return v