Skip to content

Commit

Permalink
[pydrake] Add YAML parsing for schema.Transform (#18606)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri authored Jan 18, 2023
1 parent b3a86c8 commit 4ffc07e
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 7 deletions.
69 changes: 67 additions & 2 deletions bindings/pydrake/common/schema_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <vector>

#include "pybind11/eigen.h"
#include "pybind11/eval.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

Expand Down Expand Up @@ -299,6 +300,47 @@ PYBIND11_MODULE(schema, m) {
DefAttributesUsingSerialize(&cls, cls_doc);
DefReprUsingSerialize(&cls);
DefCopyAndDeepCopy(&cls);

// To support the atypical C++ implementation of Transform::Serialize, we
// need to support attribute operations on Rotation that should actually
// apply to the Rotation::value member field instead. We'll achieve that by
// adding special-cases to getattr and setattr.
cls.def("__getattr__", [](const Class& self, py::str name) -> py::object {
if (std::holds_alternative<Rotation::Rpy>(self.value)) {
const std::string name_cxx = name;
if (name_cxx == "deg") {
py::object self_py = py::cast(self, py_rvp::reference);
return self_py.attr("value").attr(name);
}
}
if (std::holds_alternative<Rotation::AngleAxis>(self.value)) {
const std::string name_cxx = name;
if ((name_cxx == "angle_deg") || (name_cxx == "axis")) {
py::object self_py = py::cast(self, py_rvp::reference);
return self_py.attr("value").attr(name);
}
}
return py::eval("object.__getattr__")(self, name);
});
cls.def("__setattr__", [](Class& self, py::str name, py::object value) {
if (std::holds_alternative<Rotation::Rpy>(self.value)) {
const std::string name_cxx = name;
if (name_cxx == "deg") {
py::object self_py = py::cast(self, py_rvp::reference);
self_py.attr("value").attr(name) = value;
return;
}
}
if (std::holds_alternative<Rotation::AngleAxis>(self.value)) {
const std::string name_cxx = name;
if ((name_cxx == "angle_deg") || (name_cxx == "axis")) {
py::object self_py = py::cast(self, py_rvp::reference);
self_py.attr("value").attr(name) = value;
return;
}
}
py::eval("object.__setattr__")(self, name, value);
});
}

// Bindings for transform.h.
Expand Down Expand Up @@ -326,8 +368,31 @@ PYBIND11_MODULE(schema, m) {
// The Transform::Serialize does something sketchy for the "rotation" field.
// We'll undo that damage for the attribute getter and setter functions, but
// notably we must leave the __fields__ manifest unchanged to match the C++
// serialization convention.
cls.def_readwrite("rotation", &Class::rotation, cls_doc.rotation.doc);
// serialization convention and the setter needs to accept either a Rotation
// (the actual type of the property) or any of the allowed Rotation::Variant
// types (which will occur during YAML deserialization).
using RotationOrNestedValue = std::variant<Rotation, Rotation::Identity,
Rotation::Rpy, Rotation::AngleAxis, Rotation::Uniform>;
static_assert(
std::variant_size_v<RotationOrNestedValue> ==
1 /* for Rotation */ + std::variant_size_v<Rotation::Variant>);
cls.def_property(
"rotation",
// The getter is just the usual, no special magic.
[](const Class& self) { return &self.rotation; },
// The setter accepts a more generous allowed set of argument types.
[](Class& self, RotationOrNestedValue value_variant) {
std::visit(
[&self]<typename T>(const T& new_value) {
if constexpr (std::is_same_v<T, Rotation>) {
self.rotation = new_value;
} else {
self.rotation.value = new_value;
}
},
value_variant);
},
py_rvp::reference_internal, cls_doc.rotation.doc);
DefReprUsingSerialize(&cls);
DefCopyAndDeepCopy(&cls);
}
Expand Down
71 changes: 66 additions & 5 deletions bindings/pydrake/common/test/schema_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import unittest

import numpy as np
from numpy.testing import assert_allclose

import pydrake.common.schema as mut
from pydrake.common.yaml import yaml_load_typed
Expand Down Expand Up @@ -68,15 +69,15 @@ def test_rpy(self):
rpy = RollPitchYaw(x.GetDeterministicValue())
rpy_deg = np.array([math.degrees(z) for z in rpy.vector()])
expected = np.array([10.0, 20.0, 30.0])
np.testing.assert_allclose(rpy_deg, expected)
assert_allclose(rpy_deg, expected)

def test_angle_axis(self):
data = "value: !AngleAxis { angle_deg: 10.0, axis: [0, 1, 0] }"
x = yaml_load_typed(schema=mut.Rotation, data=data)
self.assertTrue(x.IsDeterministic())
aa = x.GetDeterministicValue().ToAngleAxis()
self.assertAlmostEqual(math.degrees(aa.angle()), 10.0)
np.testing.assert_allclose(aa.axis(), [0.0, 1.0, 0.0])
assert_allclose(aa.axis(), [0.0, 1.0, 0.0])

def test_uniform(self):
data = "value: !Uniform {}"
Expand All @@ -95,6 +96,66 @@ def test_rpy_uniform(self):
self.assertFalse(x.IsDeterministic())


# TODO(jwnimmer-tri) Add serialization tests for schema.Translation. For now,
# the weird C++ `MakeNameValue("rotation", &rotation.value)` is incompatible
# with how our Python deserialization works.
class TestTransformSerialization(unittest.TestCase):
"""Serialization tests related to schema/transform.h"""

def test_deterministic(self):
data = dedent("""
base_frame: foo
translation: [1.0, 2.0, 3.0]
rotation: !Rpy { deg: [10, 20, 30] }
""")
x = yaml_load_typed(schema=mut.Transform, data=data)
self.assertEqual(x.base_frame, "foo")
assert_allclose(x.translation, [1.0, 2.0, 3.0])
assert_allclose(x.rotation.value.deg, [10, 20, 30])

def test_random(self):
data = dedent("""
base_frame: bar
translation: !UniformVector
min: [1.0, 2.0, 3.0]
max: [4.0, 5.0, 6.0]
rotation: !Uniform {}
""")
x = yaml_load_typed(schema=mut.Transform, data=data)
self.assertEqual(x.base_frame, "bar")
assert_allclose(x.translation.min, [1.0, 2.0, 3.0])
assert_allclose(x.translation.max, [4.0, 5.0, 6.0])
self.assertEqual(type(x.rotation.value), mut.Rotation.Uniform)

def test_random_bounded(self):
data = dedent("""
base_frame: baz
translation: !UniformVector
min: [1.0, 2.0, 3.0]
max: [4.0, 5.0, 6.0]
rotation: !Rpy
deg: !UniformVector
min: [380, -0.25, -1.0]
max: [400, 0.25, 1.0]
""")
x = yaml_load_typed(schema=mut.Transform, data=data)
self.assertEqual(x.base_frame, "baz")
assert_allclose(x.translation.min, [1.0, 2.0, 3.0])
assert_allclose(x.translation.max, [4.0, 5.0, 6.0])
assert_allclose(x.rotation.value.deg.min, [380, -0.25, -1.0])
assert_allclose(x.rotation.value.deg.max, [400, 0.25, 1.0])

def test_random_angle_axis(self):
data = dedent("""
base_frame: quux
rotation: !AngleAxis
angle_deg: !Uniform
min: 10
max: 20
axis: !UniformVector
min: [1, 2, 3]
max: [4, 5, 6]
""")
x = yaml_load_typed(schema=mut.Transform, data=data)
self.assertEqual(x.base_frame, "quux")
self.assertEqual(x.rotation.value.angle_deg.min, 10)
self.assertEqual(x.rotation.value.angle_deg.max, 20)
assert_allclose(x.rotation.value.axis.min, [1, 2, 3])
assert_allclose(x.rotation.value.axis.max, [4, 5, 6])
2 changes: 2 additions & 0 deletions bindings/pydrake/common/test/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def test_transform(self):

dut.set_rotation_rpy_deg([0.1, 0.2, 0.3])
np.testing.assert_equal(dut.rotation.value.deg, [0.1, 0.2, 0.3])
dut.rotation.value.deg = [0.4, 0.5, 0.6]
np.testing.assert_equal(dut.rotation.value.deg, [0.4, 0.5, 0.6])

# Attributes.
self.assertEqual(mut.Transform(base_frame="base").base_frame, "base")
2 changes: 2 additions & 0 deletions bindings/pydrake/common/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _merge_yaml_dict_item_into_target(*, options, name, yaml_value,
writing the result to the field named `name` of the given `target` object.
"""
# The target can be either a dictionary or a dataclass.
assert target is not None
if isinstance(target, collections.abc.Mapping):
old_value = target[name]
setter = functools.partial(target.__setitem__, name)
Expand Down Expand Up @@ -372,6 +373,7 @@ def _merge_yaml_dict_into_target(*, options, yaml_dict,
raw strings, dictionaries, lists, etc.).
"""
assert isinstance(yaml_dict, collections.abc.Mapping), yaml_dict
assert target is not None
static_field_map = _enumerate_field_types(target_schema)
schema_names = list(static_field_map.keys())
schema_optionals = set([
Expand Down

0 comments on commit 4ffc07e

Please sign in to comment.