diff --git a/src/validators/config.rs b/src/validators/config.rs index 8ffcfd0f5..a43a7643a 100644 --- a/src/validators/config.rs +++ b/src/validators/config.rs @@ -1,7 +1,8 @@ use std::borrow::Cow; use std::str::FromStr; -use base64::Engine; +use base64::engine::general_purpose::{STANDARD, URL_SAFE}; +use base64::{DecodeError, Engine}; use pyo3::types::{PyDict, PyString}; use pyo3::{intern, prelude::*}; @@ -28,14 +29,18 @@ impl ValBytesMode { pub fn deserialize_string<'py>(self, s: &str) -> Result, ErrorType> { match self.ser { BytesMode::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))), - BytesMode::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) { - Ok(bytes) => Ok(EitherBytes::from(bytes)), - Err(err) => Err(ErrorType::BytesInvalidEncoding { + BytesMode::Base64 => URL_SAFE + .decode(s) + .or_else(|err| match err { + DecodeError::InvalidByte(_, b'/' | b'+') => STANDARD.decode(s), + _ => Err(err), + }) + .map(EitherBytes::from) + .map_err(|err| ErrorType::BytesInvalidEncoding { encoding: "base64".to_string(), encoding_error: err.to_string(), context: None, }), - }, BytesMode::Hex => match hex::decode(s) { Ok(vec) => Ok(EitherBytes::from(vec)), Err(err) => Err(ErrorType::BytesInvalidEncoding { diff --git a/tests/test_json.py b/tests/test_json.py index 03855ae67..f22850704 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -379,19 +379,26 @@ def test_partial_parse(): def test_json_bytes_base64_round_trip(): - data = b'hello' - encoded = b'"aGVsbG8="' - assert to_json(data, bytes_mode='base64') == encoded + data = b'\xd8\x07\xc1Tx$\x91F%\xf3\xf3I\xca\xd8@\x0c\xee\xc3\xab\xff\x7f\xd3\xcd\xcd\xf9\xc2\x10\xe4\xa1\xb01e' + encoded_std = b'"2AfBVHgkkUYl8/NJythADO7Dq/9/083N+cIQ5KGwMWU="' + encoded_url = b'"2AfBVHgkkUYl8_NJythADO7Dq_9_083N-cIQ5KGwMWU="' + assert to_json(data, bytes_mode='base64') == encoded_url v = SchemaValidator({'type': 'bytes'}, {'val_json_bytes': 'base64'}) - assert v.validate_json(encoded) == data + assert v.validate_json(encoded_url) == data + assert v.validate_json(encoded_std) == data + + with pytest.raises(ValidationError) as exc: + v.validate_json('"wrong!"') + [details] = exc.value.errors() + assert details['type'] == 'bytes_invalid_encoding' - assert to_json({'key': data}, bytes_mode='base64') == b'{"key":"aGVsbG8="}' + assert to_json({'key': data}, bytes_mode='base64') == b'{"key":' + encoded_url + b'}' v = SchemaValidator( {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'bytes'}}, {'val_json_bytes': 'base64'}, ) - assert v.validate_json('{"key":"aGVsbG8="}') == {'key': data} + assert v.validate_json(b'{"key":' + encoded_url + b'}') == {'key': data} def test_json_bytes_base64_invalid():