Skip to content

Commit

Permalink
add to_instances helper method to scripted Instances class
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2232

Reviewed By: theschnitz

Differential Revision: D24753888

fbshipit-source-id: 424bfc7ab0cf085e99333dbb9475d9cb32cd567c
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Nov 5, 2020
1 parent 93f0f36 commit 8e3effc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 17 deletions.
48 changes: 31 additions & 17 deletions detectron2/export/torchscript_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _gen_instance_class(fields):

class _FieldType:
def __init__(self, name, type_):
assert isinstance(name, str), f"Field name must be str, got {name}"
self.name = name
self.type_ = type_
self.annotation = f"{type_.__module__}.{type_.__name__}"
Expand All @@ -86,11 +87,13 @@ def indent(level, s):

cls_name = "Instances_patched{}".format(_counter)

field_names = tuple(x.name for x in fields)
lines.append(
f"""
class {cls_name}:
def __init__(self, image_size: Tuple[int, int]):
self.image_size = image_size
self._field_names = {field_names}
"""
)

Expand Down Expand Up @@ -155,23 +158,6 @@ def has(self, name: str) -> bool:
"""
)

# support an additional method `from_instances` to convert from the original Instances class
lines.append(
f"""
@torch.jit.unused
@staticmethod
def from_instances(instances: Instances) -> "{cls_name}":
fields = instances.get_fields()
image_size = instances.image_size
new_instances = {cls_name}(image_size)
for name, val in fields.items():
assert hasattr(new_instances, '_{{}}'.format(name)), \\
"No attribute named {{}} in {cls_name}".format(name)
setattr(new_instances, name, deepcopy(val))
return new_instances
"""
)

# support method `to`
lines.append(
f"""
Expand All @@ -197,6 +183,34 @@ def to(self, device: torch.device) -> "{cls_name}":
return ret
"""
)

# support additional methods `from_instances` and `to_instances` to
# convert from/to the original Instances class
lines.append(
f"""
@torch.jit.unused
@staticmethod
def from_instances(instances: Instances) -> "{cls_name}":
fields = instances.get_fields()
image_size = instances.image_size
new_instances = {cls_name}(image_size)
for name, val in fields.items():
assert hasattr(new_instances, '_{{}}'.format(name)), \\
"No attribute named {{}} in {cls_name}".format(name)
setattr(new_instances, name, deepcopy(val))
return new_instances
@torch.jit.unused
def to_instances(self):
ret = Instances(self.image_size)
for name in self._field_names:
val = getattr(self, "_" + name, None)
if val is not None:
ret.set(name, deepcopy(val))
return ret
"""
)

return cls_name, os.linesep.join(lines)


Expand Down
13 changes: 13 additions & 0 deletions tests/structures/test_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@ def forward(self, x: Instances):
x.a = box_tensors
script_module(x)

@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
def test_from_to_instances(self):
orig = Instances((30, 30))
orig.proposal_boxes = Boxes(torch.rand(3, 4))

fields = {"proposal_boxes": Boxes, "a": Tensor}
with patch_instances(fields) as NewInstances:
# convert to NewInstances and back
new1 = NewInstances.from_instances(orig)
new2 = new1.to_instances()
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor))
self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8e3effc

Please sign in to comment.