From 1ca4d2ab50eb57c46d7467be42f853583742654f Mon Sep 17 00:00:00 2001
From: zhaoguochun1995
Date: Fri, 25 Oct 2024 14:40:49 +0800
Subject: [PATCH] Fix autocompare falsely reporting a modified input for calls like torch.exp(x, out=x)

---
 op_tools/op_autocompare_hook.py            | 5 ++++-
 op_tools/pretty_print.py                   | 6 +++++-
 op_tools/test/test_tool_with_special_op.py | 5 +++++
 op_tools/utils.py                          | 6 +++++-
 4 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/op_tools/op_autocompare_hook.py b/op_tools/op_autocompare_hook.py
index 3ec9ecb..72b02bd 100644
--- a/op_tools/op_autocompare_hook.py
+++ b/op_tools/op_autocompare_hook.py
@@ -148,8 +148,11 @@ def run_forward_on_cpu(self):
             dtype_cast_dict=self.dtype_cast_dict,
             detach=True,
         )
+        if self.kwargs.get("out", None) is not None and self.kwargs["out"] in self.args and isinstance(self.kwargs["out"], torch.Tensor):
+            self.kwargs_cpu["out"] = self.args_cpu[self.args.index(self.kwargs["out"])]
+
         # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
-        if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
+        if (is_inplace_op(self.name, *self.args, **self.kwargs) or is_view_op(self.name)) and self.args[0].requires_grad:
             args_cpu = [item for item in self.args_cpu]
             args_cpu[0] = args_cpu[0].clone()
             self.args_cpu = tuple(args_cpu)
diff --git a/op_tools/pretty_print.py b/op_tools/pretty_print.py
index 6019c8a..971f41d 100644
--- a/op_tools/pretty_print.py
+++ b/op_tools/pretty_print.py
@@ -43,6 +43,10 @@ def packect_data_to_dict_list(op_name, inputs_dict):
         elif isinstance(arg, (str, int, float, bool)):
             data_dict_list.append({"name": op_name + (f"[{arg_index}]" if len(args) > 1 else ""), "value": arg})
     for key, value in kwargs.items():
-        data_dict_list.append({"name": op_name + f" {key}", "value": value})
+        if isinstance(value, dict):
+            value.update({"name": op_name + f" {key}"})
+            data_dict_list.append(value)
+        else:
+            data_dict_list.append({"name": op_name + f" {key}", "value": value})
 
     return data_dict_list
diff --git a/op_tools/test/test_tool_with_special_op.py b/op_tools/test/test_tool_with_special_op.py
index 0210716..4b55a13 100644
--- a/op_tools/test/test_tool_with_special_op.py
+++ b/op_tools/test/test_tool_with_special_op.py
@@ -196,6 +196,11 @@ def test_torch_tensor_device(self):
         x = torch.tensor(0, dtype=torch.int32, device="cuda")
         self.assertTrue(x.device.type == "cuda")
 
+    def test_input_is_output(self):
+        with op_tools.OpAutoCompare():
+            x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda", requires_grad=False)
+            torch.exp(x, out=x)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/op_tools/utils.py b/op_tools/utils.py
index 3db3ec5..5be6421 100644
--- a/op_tools/utils.py
+++ b/op_tools/utils.py
@@ -125,7 +125,11 @@ def is_opname_match(name, op_pattern=None):
     return False
 
 
-def is_inplace_op(name):
+def is_inplace_op(name, *args, **kwargs):
+    if kwargs.get("out", None) is not None and isinstance(kwargs["out"], torch.Tensor) and kwargs["out"] in args:
+        return True
+    if kwargs.get("inplace", False):
+        return True
     INPLACES_OP = ["torch.Tensor.__setitem__", "torch.Tensor.to", "torch.Tensor.contiguous", "torch.Tensor.to"]
     return name in INPLACES_OP or (name.endswith("_") and (not name.endswith("__")) and (name.startswith("torch.Tensor.")))
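
For context, here is a minimal, self-contained sketch of the two ideas this patch combines. It is an illustration only, not the real op_tools code: is_inplace_call and mirror_out_aliasing are hypothetical stand-ins for is_inplace_op and for the aliasing fix-up added to run_forward_on_cpu.

    import torch

    # Hypothetical, simplified stand-ins for op_tools' is_inplace_op and the
    # aliasing fix-up in run_forward_on_cpu; a sketch, not the actual API.

    def is_inplace_call(name, *args, **kwargs):
        # An out= tensor that aliases a positional input makes the call
        # effectively in-place, even though the op name has no trailing "_".
        out = kwargs.get("out", None)
        if isinstance(out, torch.Tensor) and any(out is a for a in args):
            return True
        # nn.functional-style inplace=True flag.
        if kwargs.get("inplace", False):
            return True
        # Trailing-underscore Tensor methods (add_, exp_, ...) are in-place;
        # dunder names (__setitem__, ...) are excluded by the "__" check.
        return name.startswith("torch.Tensor.") and name.endswith("_") and not name.endswith("__")

    def mirror_out_aliasing(args, kwargs, args_cpu, kwargs_cpu):
        # If out= aliases args[i] on the device side, rebuild the same aliasing
        # between the CPU copies. Otherwise the CPU reference run writes into
        # an unrelated clone and the comparison flags the input as modified.
        out = kwargs.get("out", None)
        if isinstance(out, torch.Tensor):
            for i, a in enumerate(args):
                if a is out:
                    kwargs_cpu["out"] = args_cpu[i]
                    break
        return kwargs_cpu

    # The torch.exp(x, out=x) case from the new test, run on CPU for portability:
    x = torch.randn(3, 4, 5)
    args, kwargs = (x,), {"out": x}
    args_cpu = tuple(a.detach().clone() for a in args)
    kwargs_cpu = {"out": kwargs["out"].detach().clone()}  # naive copy: aliasing lost

    kwargs_cpu = mirror_out_aliasing(args, kwargs, args_cpu, kwargs_cpu)
    assert is_inplace_call("torch.exp", *args, **kwargs)
    assert kwargs_cpu["out"] is args_cpu[0]        # aliasing restored on the CPU side
    torch.exp(args_cpu[0], out=kwargs_cpu["out"])  # mutates the same CPU copy

One detail worth noting: the sketch tests aliasing with "out is a", while the patch writes kwargs["out"] in self.args. Python's "in" short-circuits on object identity, so both accept the aliased case; for equal-but-distinct multi-element tensors, however, "in" falls back to Tensor.__eq__, whose result cannot be coerced to bool, so the explicit identity check is the safer spelling.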