From 6e101fb3f84347df713f8812d3796c2c95749997 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Wed, 30 Jun 2021 16:38:56 -0400 Subject: [PATCH] Add plain msgpack serialization (#267) * Add plain msgpack serialization * Add plain msgpack to more tests & clean up to/from file * Apply suggestions from code review Co-authored-by: Lori A. Burns * Fix file opening flags in serialization * For small changes for msgpack(-ext) Co-authored-by: Lori A. Burns --- docs/source/changelog.rst | 4 +- qcelemental/models/basemodels.py | 2 +- qcelemental/models/molecule.py | 19 +++---- qcelemental/tests/addons.py | 7 ++- qcelemental/tests/test_molecule.py | 2 +- qcelemental/util/__init__.py | 2 + qcelemental/util/serialization.py | 83 +++++++++++++++++++++++++++++- 7 files changed, 101 insertions(+), 18 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 4f8e12f9..706714f2 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -14,12 +14,12 @@ Changelog .. +++++++++ -0.21.0 / 2021-MM-DD +0.21.0 / 2021-06-30 ------------------- New Features ++++++++++++ -- (:pr:`xxx`) Serialization learned msgpack mode that, in contrast to msgpack-ext, *doesn't* embed NumPy objects. +- (:pr:`267`) Serialization learned msgpack mode that, in contrast to msgpack-ext, *doesn't* embed NumPy objects. Enhancements ++++++++++++ diff --git a/qcelemental/models/basemodels.py b/qcelemental/models/basemodels.py index 9ad18808..5b96042c 100644 --- a/qcelemental/models/basemodels.py +++ b/qcelemental/models/basemodels.py @@ -61,7 +61,7 @@ def parse_raw(cls, data: Union[bytes, str], *, encoding: str = None) -> "ProtoMo if encoding.endswith(("json", "javascript", "pickle")): return super().parse_raw(data, content_type=encoding) - elif encoding in ["msgpack-ext", "json-ext"]: + elif encoding in ["msgpack-ext", "json-ext", "msgpack"]: obj = deserialize(data, encoding) else: raise TypeError(f"Content type '{encoding}' not understood.") diff --git a/qcelemental/models/molecule.py b/qcelemental/models/molecule.py index 72e3e530..d4d1db8c 100644 --- a/qcelemental/models/molecule.py +++ b/qcelemental/models/molecule.py @@ -40,7 +40,7 @@ ".xyz": "xyz", ".psimol": "psi4", ".psi4": "psi4", - ".msgpack": "msgpack", + ".msgpack": "msgpack-ext", } @@ -943,11 +943,11 @@ def from_file(cls, filename: str, dtype: Optional[str] = None, *, orient: bool = data = infile.read() elif dtype == "numpy": data = np.load(filename) - elif dtype == "json": + elif dtype in ["json", "json-ext"]: with open(filename, "r") as infile: - data = json.load(infile) + data = deserialize(infile.read(), encoding="json-ext") dtype = "dict" - elif dtype == "msgpack": + elif dtype in ["msgpack", "msgpack-ext"]: with open(filename, "rb") as infile_bytes: data = deserialize(infile_bytes.read(), encoding="msgpack-ext") dtype = "dict" @@ -975,23 +975,20 @@ def to_file(self, filename: str, dtype: Optional[str] = None) -> None: else: raise KeyError(f"Could not infer dtype from filename: `{filename}`") - flags = "w" if dtype in ["xyz", "xyz+", "psi4"]: stringified = self.to_string(dtype) - elif dtype in ["json"]: - stringified = self.serialize("json") - elif dtype in ["msgpack", "msgpack-ext"]: - stringified = self.serialize("msgpack-ext") - flags = "wb" + elif dtype in ["json", "json-ext", "msgpack", "msgpack-ext"]: + stringified = self.serialize(dtype) elif dtype in ["numpy"]: elements = np.array(self.atomic_numbers).reshape(-1, 1) npmol = np.hstack((elements, self.geometry * constants.conversion_factor("bohr", "angstroms"))) np.save(filename, npmol) return - else: raise KeyError(f"Dtype `{dtype}` is not valid") + flags = "wb" if dtype.startswith("msgpack") else "w" + with open(filename, flags) as handle: handle.write(stringified) diff --git a/qcelemental/tests/addons.py b/qcelemental/tests/addons.py index fe2d9e4e..75277dce 100644 --- a/qcelemental/tests/addons.py +++ b/qcelemental/tests/addons.py @@ -39,7 +39,12 @@ def internet_connection(): reason="Not detecting module py3Dmol. Install package if necessary and add to envvar PYTHONPATH", ) -serialize_extensions = ["json", "json-ext", pytest.param("msgpack-ext", marks=using_msgpack)] +serialize_extensions = [ + "json", + "json-ext", + pytest.param("msgpack", marks=using_msgpack), + pytest.param("msgpack-ext", marks=using_msgpack), +] @contextmanager diff --git a/qcelemental/tests/test_molecule.py b/qcelemental/tests/test_molecule.py index aa56b0d7..05b2180a 100644 --- a/qcelemental/tests/test_molecule.py +++ b/qcelemental/tests/test_molecule.py @@ -219,7 +219,7 @@ def test_to_from_file_simple(tmp_path, dtype, filext): ) p = tmp_path / ("water." + filext) - benchmol.to_file(p) + benchmol.to_file(p, dtype=dtype) mol = Molecule.from_file(p) diff --git a/qcelemental/util/__init__.py b/qcelemental/util/__init__.py index dbc03580..04ae84cd 100644 --- a/qcelemental/util/__init__.py +++ b/qcelemental/util/__init__.py @@ -27,5 +27,7 @@ jsonext_loads, msgpackext_dumps, msgpackext_loads, + msgpack_dumps, + msgpack_loads, serialize, ) diff --git a/qcelemental/util/serialization.py b/qcelemental/util/serialization.py index ef4b09ca..a8f719e1 100644 --- a/qcelemental/util/serialization.py +++ b/qcelemental/util/serialization.py @@ -238,6 +238,78 @@ def json_loads(data: str) -> Any: return json.loads(data, object_hook=jsonext_decode) +## MSGPack + + +def msgpack_encode(obj: Any) -> Any: + r""" + Encodes an object using pydantic. Converts numpy arrays to plain python lists + + Parameters + ---------- + obj : Any + Any object that can be serialized with pydantic and NumPy encoding techniques. + + Returns + ------- + Any + A msgpack compatible form of the object. + """ + + try: + return pydantic_encoder(obj) + except TypeError: + pass + + if isinstance(obj, np.ndarray): + if obj.shape: + return obj.ravel().tolist() + else: + return obj.tolist() + + return obj + + +def msgpack_dumps(data: Any) -> str: + r"""Safe serialization of a Python object to msgpack binary representation using all known encoders. + For NumPy, converts to lists. + + Parameters + ---------- + data : Any + A encodable python object. + + Returns + ------- + str + A msgpack representation of the data in bytes. + """ + + which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg) + + return msgpack.dumps(data, default=msgpack_encode, use_bin_type=True) + + +def msgpack_loads(data: str) -> Any: + r"""Deserializes a msgpack byte representation of known objects into those objects. + + Parameters + ---------- + data : bytes + The serialized msgpack byte array. + + Returns + ------- + Any + The deserialized Python objects. + """ + + which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg) + + # Doesn't hurt anything to try to load msgpack-ext as well + return msgpack.loads(data, object_hook=msgpackext_decode, raw=False) + + ## Helper functions @@ -261,6 +333,8 @@ def serialize(data: Any, encoding: str) -> Union[str, bytes]: return json_dumps(data) elif encoding.lower() == "json-ext": return jsonext_dumps(data) + elif encoding.lower() == "msgpack": + return msgpack_dumps(data) elif encoding.lower() == "msgpack-ext": return msgpackext_dumps(data) else: @@ -288,8 +362,13 @@ def deserialize(blob: Union[str, bytes], encoding: str) -> Any: elif encoding.lower() == "json-ext": assert isinstance(blob, (str, bytes)) return jsonext_loads(blob) - elif encoding.lower() in ["msgpack", "msgpack-ext"]: + elif encoding.lower() in ["msgpack"]: + assert isinstance(blob, bytes) + return msgpack_loads(blob) + elif encoding.lower() in ["msgpack-ext"]: assert isinstance(blob, bytes) return msgpackext_loads(blob) else: - raise KeyError(f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack-ext'") + raise KeyError( + f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'" + )