diff --git a/bindings/pydrake/common/schema_py.cc b/bindings/pydrake/common/schema_py.cc index b1b89eb74e87..b461be5d3e8e 100644 --- a/bindings/pydrake/common/schema_py.cc +++ b/bindings/pydrake/common/schema_py.cc @@ -2,6 +2,7 @@ #include #include "pybind11/eigen.h" +#include "pybind11/eval.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -299,6 +300,47 @@ PYBIND11_MODULE(schema, m) { DefAttributesUsingSerialize(&cls, cls_doc); DefReprUsingSerialize(&cls); DefCopyAndDeepCopy(&cls); + + // To support the atypical C++ implementation of Transform::Serialize, we + // need to support attribute operations on Rotation that should actually + // apply to the Rotation::value member field instead. We'll achieve that by + // adding special-cases to getattr and setattr. + cls.def("__getattr__", [](const Class& self, py::str name) -> py::object { + if (std::holds_alternative(self.value)) { + const std::string name_cxx = name; + if (name_cxx == "deg") { + py::object self_py = py::cast(self, py_rvp::reference); + return self_py.attr("value").attr(name); + } + } + if (std::holds_alternative(self.value)) { + const std::string name_cxx = name; + if ((name_cxx == "angle_deg") || (name_cxx == "axis")) { + py::object self_py = py::cast(self, py_rvp::reference); + return self_py.attr("value").attr(name); + } + } + return py::eval("object.__getattr__")(self, name); + }); + cls.def("__setattr__", [](Class& self, py::str name, py::object value) { + if (std::holds_alternative(self.value)) { + const std::string name_cxx = name; + if (name_cxx == "deg") { + py::object self_py = py::cast(self, py_rvp::reference); + self_py.attr("value").attr(name) = value; + return; + } + } + if (std::holds_alternative(self.value)) { + const std::string name_cxx = name; + if ((name_cxx == "angle_deg") || (name_cxx == "axis")) { + py::object self_py = py::cast(self, py_rvp::reference); + self_py.attr("value").attr(name) = value; + return; + } + } + py::eval("object.__setattr__")(self, name, value); + }); } // Bindings for transform.h. @@ -326,8 +368,31 @@ PYBIND11_MODULE(schema, m) { // The Transform::Serialize does something sketchy for the "rotation" field. // We'll undo that damage for the attribute getter and setter functions, but // notably we must leave the __fields__ manifest unchanged to match the C++ - // serialization convention. - cls.def_readwrite("rotation", &Class::rotation, cls_doc.rotation.doc); + // serialization convention and the setter needs to accept either a Rotation + // (the actual type of the property) or any of the allowed Rotation::Variant + // types (which will occur during YAML deserialization). + using RotationOrNestedValue = std::variant; + static_assert( + std::variant_size_v == + 1 /* for Rotation */ + std::variant_size_v); + cls.def_property( + "rotation", + // The getter is just the usual, no special magic. + [](const Class& self) { return &self.rotation; }, + // The setter accepts a more generous allowed set of argument types. + [](Class& self, RotationOrNestedValue value_variant) { + std::visit( + [&self](const T& new_value) { + if constexpr (std::is_same_v) { + self.rotation = new_value; + } else { + self.rotation.value = new_value; + } + }, + value_variant); + }, + py_rvp::reference_internal, cls_doc.rotation.doc); DefReprUsingSerialize(&cls); DefCopyAndDeepCopy(&cls); } diff --git a/bindings/pydrake/common/test/schema_serialization_test.py b/bindings/pydrake/common/test/schema_serialization_test.py index 3aa812d9d97f..a7e8a057476f 100644 --- a/bindings/pydrake/common/test/schema_serialization_test.py +++ b/bindings/pydrake/common/test/schema_serialization_test.py @@ -5,6 +5,7 @@ import unittest import numpy as np +from numpy.testing import assert_allclose import pydrake.common.schema as mut from pydrake.common.yaml import yaml_load_typed @@ -68,7 +69,7 @@ def test_rpy(self): rpy = RollPitchYaw(x.GetDeterministicValue()) rpy_deg = np.array([math.degrees(z) for z in rpy.vector()]) expected = np.array([10.0, 20.0, 30.0]) - np.testing.assert_allclose(rpy_deg, expected) + assert_allclose(rpy_deg, expected) def test_angle_axis(self): data = "value: !AngleAxis { angle_deg: 10.0, axis: [0, 1, 0] }" @@ -76,7 +77,7 @@ def test_angle_axis(self): self.assertTrue(x.IsDeterministic()) aa = x.GetDeterministicValue().ToAngleAxis() self.assertAlmostEqual(math.degrees(aa.angle()), 10.0) - np.testing.assert_allclose(aa.axis(), [0.0, 1.0, 0.0]) + assert_allclose(aa.axis(), [0.0, 1.0, 0.0]) def test_uniform(self): data = "value: !Uniform {}" @@ -95,6 +96,66 @@ def test_rpy_uniform(self): self.assertFalse(x.IsDeterministic()) -# TODO(jwnimmer-tri) Add serialization tests for schema.Translation. For now, -# the weird C++ `MakeNameValue("rotation", &rotation.value)` is incompatible -# with how our Python deserialization works. +class TestTransformSerialization(unittest.TestCase): + """Serialization tests related to schema/transform.h""" + + def test_deterministic(self): + data = dedent(""" + base_frame: foo + translation: [1.0, 2.0, 3.0] + rotation: !Rpy { deg: [10, 20, 30] } + """) + x = yaml_load_typed(schema=mut.Transform, data=data) + self.assertEqual(x.base_frame, "foo") + assert_allclose(x.translation, [1.0, 2.0, 3.0]) + assert_allclose(x.rotation.value.deg, [10, 20, 30]) + + def test_random(self): + data = dedent(""" + base_frame: bar + translation: !UniformVector + min: [1.0, 2.0, 3.0] + max: [4.0, 5.0, 6.0] + rotation: !Uniform {} + """) + x = yaml_load_typed(schema=mut.Transform, data=data) + self.assertEqual(x.base_frame, "bar") + assert_allclose(x.translation.min, [1.0, 2.0, 3.0]) + assert_allclose(x.translation.max, [4.0, 5.0, 6.0]) + self.assertEqual(type(x.rotation.value), mut.Rotation.Uniform) + + def test_random_bounded(self): + data = dedent(""" + base_frame: baz + translation: !UniformVector + min: [1.0, 2.0, 3.0] + max: [4.0, 5.0, 6.0] + rotation: !Rpy + deg: !UniformVector + min: [380, -0.25, -1.0] + max: [400, 0.25, 1.0] + """) + x = yaml_load_typed(schema=mut.Transform, data=data) + self.assertEqual(x.base_frame, "baz") + assert_allclose(x.translation.min, [1.0, 2.0, 3.0]) + assert_allclose(x.translation.max, [4.0, 5.0, 6.0]) + assert_allclose(x.rotation.value.deg.min, [380, -0.25, -1.0]) + assert_allclose(x.rotation.value.deg.max, [400, 0.25, 1.0]) + + def test_random_angle_axis(self): + data = dedent(""" + base_frame: quux + rotation: !AngleAxis + angle_deg: !Uniform + min: 10 + max: 20 + axis: !UniformVector + min: [1, 2, 3] + max: [4, 5, 6] + """) + x = yaml_load_typed(schema=mut.Transform, data=data) + self.assertEqual(x.base_frame, "quux") + self.assertEqual(x.rotation.value.angle_deg.min, 10) + self.assertEqual(x.rotation.value.angle_deg.max, 20) + assert_allclose(x.rotation.value.axis.min, [1, 2, 3]) + assert_allclose(x.rotation.value.axis.max, [4, 5, 6]) diff --git a/bindings/pydrake/common/test/schema_test.py b/bindings/pydrake/common/test/schema_test.py index 8f21e903c9e1..9d795142d09c 100644 --- a/bindings/pydrake/common/test/schema_test.py +++ b/bindings/pydrake/common/test/schema_test.py @@ -211,6 +211,8 @@ def test_transform(self): dut.set_rotation_rpy_deg([0.1, 0.2, 0.3]) np.testing.assert_equal(dut.rotation.value.deg, [0.1, 0.2, 0.3]) + dut.rotation.value.deg = [0.4, 0.5, 0.6] + np.testing.assert_equal(dut.rotation.value.deg, [0.4, 0.5, 0.6]) # Attributes. self.assertEqual(mut.Transform(base_frame="base").base_frame, "base") diff --git a/bindings/pydrake/common/yaml.py b/bindings/pydrake/common/yaml.py index 694a202dec27..ee9b544ccf41 100644 --- a/bindings/pydrake/common/yaml.py +++ b/bindings/pydrake/common/yaml.py @@ -211,6 +211,7 @@ def _merge_yaml_dict_item_into_target(*, options, name, yaml_value, writing the result to the field named `name` of the given `target` object. """ # The target can be either a dictionary or a dataclass. + assert target is not None if isinstance(target, collections.abc.Mapping): old_value = target[name] setter = functools.partial(target.__setitem__, name) @@ -372,6 +373,7 @@ def _merge_yaml_dict_into_target(*, options, yaml_dict, raw strings, dictionaries, lists, etc.). """ assert isinstance(yaml_dict, collections.abc.Mapping), yaml_dict + assert target is not None static_field_map = _enumerate_field_types(target_schema) schema_names = list(static_field_map.keys()) schema_optionals = set([