Skip to content

Commit

Permalink
Fixed the bug when the operator offline runner runs the inplace opera…
Browse files Browse the repository at this point in the history
…tor more than once
  • Loading branch information
zhaoguochun1995 committed Sep 14, 2024
1 parent 14b2c68 commit 2f29e6f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
6 changes: 3 additions & 3 deletions op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def register_grad_fn_hook(self, tensor):
hook_handle = None

def grad_fun(grad_inputs, grad_outputs):
hook_handle.remove()
self.compare_hook.run_backward_on_cpu(grad_inputs, grad_outputs)
self.compare_hook.compare_all_grad()
hook_handle.remove()

hook_handle = tensor.grad_fn.register_hook(grad_fun)
return grad_fun
Expand All @@ -79,9 +79,9 @@ def register_tensor_hook(self, index, tensor):
hook_handle = None

def grad_fun(grad):
hook_handle.remove()
self.compare_hook.set_input_grad(index, grad)
self.compare_hook.compare_all_grad()
hook_handle.remove()

hook_handle = tensor.register_hook(grad_fun)

Expand Down Expand Up @@ -334,7 +334,7 @@ def dump_all_autocompare_info():
with open(file_name, "w") as f:
f.write(data_string)
f.close
print(f"op elasped info saved to {file_name}")
print(f"op autocompare info saved to {file_name}")


atexit.register(dump_all_autocompare_info)
5 changes: 4 additions & 1 deletion op_tools/op_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import torch
import time
from .utils import to_device, get_function_from_string, traverse_container
from .utils import to_device, get_function_from_string, traverse_container, is_inplace_op
from .op_autocompare_hook import OpAutoCompareHook


Expand Down Expand Up @@ -147,6 +147,9 @@ def load_backward_data(self):
self.grad_outputs_cpu = None

def run_forward(self):
if is_inplace_op(self.name):
self.args = to_device("cuda", self.args_cpu)
self.kwargs = to_device("cuda", self.kwargs_cpu)
self.run_before_forward()
self.result = self.func(*self.args, **self.kwargs)
self.run_after_forward()
Expand Down
13 changes: 7 additions & 6 deletions op_tools/test/test_run_op_from_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ def run_command_in_sub_process(commands):


if __name__ == "__main__":
data_file_dir = "op_tools_results/op_capture_result_raw"
shutil.copytree("op_tools_results/op_capture_results", data_file_dir, dirs_exist_ok=True)
raw_data_dir = "op_tools_results"
test_data_dir = "op_tools_results_test"
shutil.copytree(raw_data_dir, test_data_dir, dirs_exist_ok=True)

commands = f"python op_tools/run_op_from_data.py {data_file_dir} --sync_time_measure --run_times 10"
commands = f"python op_tools/run_op_from_data.py {test_data_dir} --sync_time_measure --run_times 10"
run_command_in_sub_process(commands)

commands = f"python op_tools/run_op_from_data.py {data_file_dir} --sync_time_measure --run_times 1 --acc_check"
commands = f"python op_tools/run_op_from_data.py {test_data_dir} --sync_time_measure --run_times 2 --acc_check"
run_command_in_sub_process(commands)

shutil.rmtree("op_tools_results/op_capture_results")
shutil.move(data_file_dir, "op_tools_results/op_capture_results")
shutil.rmtree(raw_data_dir)
shutil.move(test_data_dir, raw_data_dir)

0 comments on commit 2f29e6f

Please sign in to comment.