Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub(crate) fn infer_to_python_known(
PyList::new(py, items).into_py(py)
}
ObType::Path => value.str()?.into_py(py),
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py),
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -505,6 +506,16 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
let s = value.str().map_err(py_err_se_err)?.to_str().map_err(py_err_se_err)?;
serializer.serialize_str(s)
}
ObType::Pattern => {
let s = value
.getattr(intern!(value.py(), "pattern"))
.map_err(py_err_se_err)?
.str()
.map_err(py_err_se_err)?
.to_str()
.map_err(py_err_se_err)?;
serializer.serialize_str(s)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
Expand Down Expand Up @@ -628,6 +639,7 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra:
infer_json_key(k, extra)
}
ObType::Path => Ok(key.str()?.to_string_lossy()),
ObType::Pattern => Ok(key.getattr(intern!(key.py(), "pattern"))?.str()?.to_string_lossy()),
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_key = fallback.call1((key,))?;
Expand Down
10 changes: 10 additions & 0 deletions src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub struct ObTypeLookup {
generator_object: PyObject,
// path
path_object: PyObject,
// pattern
pattern_object: PyObject,
// uuid type
uuid_object: PyObject,
}
Expand Down Expand Up @@ -87,6 +89,7 @@ impl ObTypeLookup {
.unwrap()
.to_object(py),
path_object: py.import("pathlib").unwrap().getattr("Path").unwrap().to_object(py),
pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().to_object(py),
uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().to_object(py),
}
}
Expand Down Expand Up @@ -150,6 +153,7 @@ impl ObTypeLookup {
ObType::Enum => self.enum_object.as_ptr() as usize == ob_type,
ObType::Generator => self.generator_object.as_ptr() as usize == ob_type,
ObType::Path => self.path_object.as_ptr() as usize == ob_type,
ObType::Pattern => self.path_object.as_ptr() as usize == ob_type,
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
ObType::Unknown => false,
};
Expand Down Expand Up @@ -242,6 +246,8 @@ impl ObTypeLookup {
ObType::Generator
} else if ob_type == self.path_object.as_ptr() as usize {
ObType::Path
} else if ob_type == self.pattern_object.as_ptr() as usize {
ObType::Pattern
} else {
// this allows for subtypes of the supported class types,
// if `ob_type` didn't match any member of self, we try again with the next base type pointer
Expand Down Expand Up @@ -319,6 +325,8 @@ impl ObTypeLookup {
ObType::Generator
} else if value.is_instance(self.path_object.as_ref(py)).unwrap_or(false) {
ObType::Path
} else if value.is_instance(self.pattern_object.as_ref(py)).unwrap_or(false) {
ObType::Pattern
} else {
ObType::Unknown
}
Expand Down Expand Up @@ -396,6 +404,8 @@ pub enum ObType {
Generator,
// Path
Path,
//Pattern,
Pattern,
// Uuid
Uuid,
// unknown type
Expand Down
3 changes: 2 additions & 1 deletion tests/serializers/test_any.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import json
import platform
import re
import sys
from collections import namedtuple
from datetime import date, datetime, time, timedelta, timezone
Expand Down Expand Up @@ -437,7 +438,7 @@ def test_base64():
(lambda: MyEnum.a, {}, b'1'),
(lambda: MyEnum.b, {}, b'"b"'),
(lambda: [MyDataclass(1, 'a', 2), MyModel(a=2, b='b')], {}, b'[{"a":1,"b":"a"},{"a":2,"b":"b"}]'),
# # (lambda: re.compile('^regex$'), b'"^regex$"'),
(lambda: re.compile('^regex$'), {}, b'"^regex$"'),
],
)
def test_encoding(any_serializer, gen_input, kwargs, expected_json):
Expand Down