Skip to content

Commit

Permalink
Add plain msgpack serialization (#267)
Browse files Browse the repository at this point in the history
* Add plain msgpack serialization

* Add plain msgpack to more tests & clean up to/from file

* Apply suggestions from code review

Co-authored-by: Lori A. Burns <[email protected]>

* Fix file opening flags in serialization

* For small changes for msgpack(-ext)

Co-authored-by: Lori A. Burns <[email protected]>
  • Loading branch information
bennybp and loriab authored Jun 30, 2021
1 parent 5e413e6 commit 6e101fb
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 18 deletions.
4 changes: 2 additions & 2 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ Changelog
.. +++++++++
0.21.0 / 2021-MM-DD
0.21.0 / 2021-06-30
-------------------

New Features
++++++++++++
- (:pr:`xxx`) Serialization learned msgpack mode that, in contrast to msgpack-ext, *doesn't* embed NumPy objects.
- (:pr:`267`) Serialization learned msgpack mode that, in contrast to msgpack-ext, *doesn't* embed NumPy objects.

Enhancements
++++++++++++
Expand Down
2 changes: 1 addition & 1 deletion qcelemental/models/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_raw(cls, data: Union[bytes, str], *, encoding: str = None) -> "ProtoMo

if encoding.endswith(("json", "javascript", "pickle")):
return super().parse_raw(data, content_type=encoding)
elif encoding in ["msgpack-ext", "json-ext"]:
elif encoding in ["msgpack-ext", "json-ext", "msgpack"]:
obj = deserialize(data, encoding)
else:
raise TypeError(f"Content type '{encoding}' not understood.")
Expand Down
19 changes: 8 additions & 11 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
".xyz": "xyz",
".psimol": "psi4",
".psi4": "psi4",
".msgpack": "msgpack",
".msgpack": "msgpack-ext",
}


Expand Down Expand Up @@ -943,11 +943,11 @@ def from_file(cls, filename: str, dtype: Optional[str] = None, *, orient: bool =
data = infile.read()
elif dtype == "numpy":
data = np.load(filename)
elif dtype == "json":
elif dtype in ["json", "json-ext"]:
with open(filename, "r") as infile:
data = json.load(infile)
data = deserialize(infile.read(), encoding="json-ext")
dtype = "dict"
elif dtype == "msgpack":
elif dtype in ["msgpack", "msgpack-ext"]:
with open(filename, "rb") as infile_bytes:
data = deserialize(infile_bytes.read(), encoding="msgpack-ext")
dtype = "dict"
Expand Down Expand Up @@ -975,23 +975,20 @@ def to_file(self, filename: str, dtype: Optional[str] = None) -> None:
else:
raise KeyError(f"Could not infer dtype from filename: `{filename}`")

flags = "w"
if dtype in ["xyz", "xyz+", "psi4"]:
stringified = self.to_string(dtype)
elif dtype in ["json"]:
stringified = self.serialize("json")
elif dtype in ["msgpack", "msgpack-ext"]:
stringified = self.serialize("msgpack-ext")
flags = "wb"
elif dtype in ["json", "json-ext", "msgpack", "msgpack-ext"]:
stringified = self.serialize(dtype)
elif dtype in ["numpy"]:
elements = np.array(self.atomic_numbers).reshape(-1, 1)
npmol = np.hstack((elements, self.geometry * constants.conversion_factor("bohr", "angstroms")))
np.save(filename, npmol)
return

else:
raise KeyError(f"Dtype `{dtype}` is not valid")

flags = "wb" if dtype.startswith("msgpack") else "w"

with open(filename, flags) as handle:
handle.write(stringified)

Expand Down
7 changes: 6 additions & 1 deletion qcelemental/tests/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def internet_connection():
reason="Not detecting module py3Dmol. Install package if necessary and add to envvar PYTHONPATH",
)

serialize_extensions = ["json", "json-ext", pytest.param("msgpack-ext", marks=using_msgpack)]
serialize_extensions = [
"json",
"json-ext",
pytest.param("msgpack", marks=using_msgpack),
pytest.param("msgpack-ext", marks=using_msgpack),
]


@contextmanager
Expand Down
2 changes: 1 addition & 1 deletion qcelemental/tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_to_from_file_simple(tmp_path, dtype, filext):
)

p = tmp_path / ("water." + filext)
benchmol.to_file(p)
benchmol.to_file(p, dtype=dtype)

mol = Molecule.from_file(p)

Expand Down
2 changes: 2 additions & 0 deletions qcelemental/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@
jsonext_loads,
msgpackext_dumps,
msgpackext_loads,
msgpack_dumps,
msgpack_loads,
serialize,
)
83 changes: 81 additions & 2 deletions qcelemental/util/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,78 @@ def json_loads(data: str) -> Any:
return json.loads(data, object_hook=jsonext_decode)


## MSGPack


def msgpack_encode(obj: Any) -> Any:
    r"""
    Fallback encoder for plain msgpack that turns NumPy arrays into Python lists.

    The object is first handed to ``pydantic_encoder``. Only when that raises
    ``TypeError`` is the NumPy handling below attempted.

    Parameters
    ----------
    obj : Any
        The object to convert into something msgpack can pack.

    Returns
    -------
    Any
        The result of ``pydantic_encoder`` when it accepts the object; a
        flattened list for a NumPy array with a non-empty shape; the result of
        ``tolist()`` for a zero-dimensional array; otherwise ``obj`` unchanged.
    """

    try:
        encoded = pydantic_encoder(obj)
    except TypeError:
        # pydantic could not handle this type; fall through to the NumPy handling.
        pass
    else:
        return encoded

    # Anything that is not an ndarray is handed back untouched.
    if not isinstance(obj, np.ndarray):
        return obj

    # An empty shape tuple means a zero-dimensional array, which has nothing to
    # flatten. Arrays with a shape are flattened to one dimension before conversion.
    return obj.ravel().tolist() if obj.shape else obj.tolist()


def msgpack_dumps(data: Any) -> bytes:
    r"""Safe serialization of a Python object to msgpack binary representation using all known encoders.

    For NumPy, converts to lists (through ``msgpack_encode``).

    Parameters
    ----------
    data : Any
        An encodable python object.

    Returns
    -------
    bytes
        A msgpack representation of the data in bytes.
    """

    # NOTE(review): presumably raises with ``_msgpack_which_msg`` when msgpack is
    # not importable (``raise_error=True``) -- confirm against ``which_import``.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # ``default`` is invoked by msgpack for objects it cannot pack natively.
    return msgpack.dumps(data, default=msgpack_encode, use_bin_type=True)


def msgpack_loads(data: bytes) -> Any:
    r"""Deserializes a msgpack byte representation of known objects into those objects.

    Parameters
    ----------
    data : bytes
        The serialized msgpack byte array.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    # NOTE(review): presumably raises with ``_msgpack_which_msg`` when msgpack is
    # not importable (``raise_error=True``) -- confirm against ``which_import``.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # Doesn't hurt anything to try to load msgpack-ext as well
    return msgpack.loads(data, object_hook=msgpackext_decode, raw=False)


## Helper functions


Expand All @@ -261,6 +333,8 @@ def serialize(data: Any, encoding: str) -> Union[str, bytes]:
return json_dumps(data)
elif encoding.lower() == "json-ext":
return jsonext_dumps(data)
elif encoding.lower() == "msgpack":
return msgpack_dumps(data)
elif encoding.lower() == "msgpack-ext":
return msgpackext_dumps(data)
else:
Expand Down Expand Up @@ -288,8 +362,13 @@ def deserialize(blob: Union[str, bytes], encoding: str) -> Any:
elif encoding.lower() == "json-ext":
assert isinstance(blob, (str, bytes))
return jsonext_loads(blob)
elif encoding.lower() in ["msgpack", "msgpack-ext"]:
elif encoding.lower() in ["msgpack"]:
assert isinstance(blob, bytes)
return msgpack_loads(blob)
elif encoding.lower() in ["msgpack-ext"]:
assert isinstance(blob, bytes)
return msgpackext_loads(blob)
else:
raise KeyError(f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack-ext'")
raise KeyError(
f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'"
)

0 comments on commit 6e101fb

Please sign in to comment.