Skip to content

Commit

Permalink
Improve the OpCapture and other minor changes (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoguochun1995 authored Oct 9, 2024
1 parent 9e50835 commit 949a690
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 42 deletions.
98 changes: 65 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,48 +53,80 @@ with op_tools.OpCapture():
#### **抓取前向和反向的所有输入输出**

```
op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/output.pth saved
apply OpCaptureHook on torch.Tensor.mul
op_capture_result/0/2024-08-06--11-41/torch.Tensor.mul/9/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.mul/9/output.pth saved
...
apply OpCaptureHook on torch.Tensor.add
op_capture_result/0/2024-08-06--11-41/torch.Tensor.add/10/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.add/10/output.pth saved
apply OpCaptureHook on torch.Tensor.sub
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sub/11/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sub/11/output.pth saved
apply OpCaptureHook on torch.Tensor.div
op_capture_result/0/2024-08-06--11-41/torch.Tensor.div/12/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.div/12/output.pth saved
apply OpCaptureHook on torch.Tensor.sort
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/output.pth saved
apply OpCaptureHook on torch.Tensor.sum
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/output.pth saved
skip OpCaptureHook on torch.Tensor.backward
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/grad_inputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/grad_outputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/grad_inputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/grad_outputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/grad_inputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/grad_outputs.pth saved
op_tools_results/op_capture_results/torch.Tensor.add/283699/161/2024-10-09-11-42-15/input.pth saved
op_tools_results/op_capture_results/torch.Tensor.add/283699/161/2024-10-09-11-42-15/output.pth saved
torch.Tensor.add forward_id:161/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
| name | device | dtype | numel | shape | stride | requires_grad | layout | data_ptr | value |
+----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
| torch.Tensor.add inputs[0] | npu:0 | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False | torch.strided | 20067180408320 | |
| torch.Tensor.add inputs[1] | | | | | | | | | 1e-05 |
| torch.Tensor.add outputs | npu:0 | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False | torch.strided | 20067180474368 | |
+----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
apply OpCaptureHook on torch.rsqrt
op_tools_results/op_capture_results/torch.rsqrt/283699/162/2024-10-09-11-42-15/input.pth saved
op_tools_results/op_capture_results/torch.rsqrt/283699/162/2024-10-09-11-42-15/output.pth saved
torch.rsqrt forward_id:162/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
| name | device | dtype | numel | shape | stride | requires_grad | layout | data_ptr |
+---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
| torch.rsqrt inputs | npu:0 | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False | torch.strided | 20067180474368 |
| torch.rsqrt outputs | npu:0 | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False | torch.strided | 20067180540416 |
+---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
apply OpCaptureHook on torch.Tensor.mul
op_tools_results/op_capture_results/torch.Tensor.mul/283699/163/2024-10-09-11-42-15/input.pth saved
op_tools_results/op_capture_results/torch.Tensor.mul/283699/163/2024-10-09-11-42-15/output.pth saved
torch.Tensor.mul forward_id:163/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
| name | device | dtype | numel | shape | stride | requires_grad | layout | data_ptr |
+----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
| torch.Tensor.mul inputs[0] | npu:0 | bfloat16 | 33554432 | (1, 16384, 2048) | (33554432, 2048, 1) | True | torch.strided | 20074677141504 |
| torch.Tensor.mul inputs[1] | npu:0 | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False | torch.strided | 20067180540416 |
| torch.Tensor.mul outputs | npu:0 | float32 | 33554432 | (1, 16384, 2048) | (33554432, 2048, 1) | False | torch.strided | 20075012687360 |
+----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
...
```

#### **只抓取sort算子的参数,忽略其他算子 OP_CAPTURE_LIST=torch.Tensor.sort**
```
...
skip OpCaptureHook on torch.Tensor.mul
skip OpCaptureHook on torch.Tensor.add
skip OpCaptureHook on torch.Tensor.sub
skip OpCaptureHook on torch.Tensor.div
apply OpCaptureHook on torch.Tensor.sort
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/input.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/output.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/grad_inputs.pth saved
op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/grad_outputs.pth saved
op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/input.pth saved
op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/output.pth saved
torch.Tensor.sort forward_id:59/2024-10-09-11-40-14 /deeplink_afs/zhaoguochun/ditorch2/op_tools/test/test_op_capture.py:15 f: sorted, indices = e.sort() # return torch.return_type.sort
+----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
| name | device | dtype | numel | shape | stride | requires_grad | layout | data_ptr |
+----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
| torch.Tensor.sort inputs | npu:0 | float32 | 200 | (10, 20) | (20, 1) | True | torch.strided | 20067179830784 |
| torch.Tensor.sort outputs [0][0] | npu:0 | float32 | 200 | (10, 20) | (20, 1) | True | torch.strided | 20067179831808 |
| torch.Tensor.sort outputs [0][1] | npu:0 | int64 | 200 | (10, 20) | (20, 1) | False | torch.strided | 20067179832832 |
+----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
skip OpCaptureHook on torch.Tensor.__getitem__
skip OpCaptureHook on torch.Tensor.sum
op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/grad_inputs.pth saved
op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/grad_outputs.pth saved
torch.Tensor.sort forward_id:<built-in function id>
+-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
| name | device | dtype | numel | shape | stride | requires_grad | layout | data_ptr |
+-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
| torch.Tensor.sort grad_output | npu:0 | float32 | 200 | (10, 20) | (20, 1) | False | torch.strided | 20067179835904 |
| torch.Tensor.sort grad_inputs | npu:0 | float32 | 200 | (10, 20) | (20, 1) | False | torch.strided | 20067179836928 |
+-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
...
```

Expand Down
2 changes: 1 addition & 1 deletion op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def compare_backward_relate(self):

id = self.forward_op_id
self = None
garbage_collect(id, 10)
garbage_collect(id, 2)

def save_forward_args(self):
save_op_args(
Expand Down
33 changes: 29 additions & 4 deletions op_tools/op_capture_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,45 @@
import torch
import time
from .base_hook import BaseHook, DisableHookGuard
from .utils import traverse_container, is_opname_match
from .save_op_args import save_op_args
from .utils import traverse_container, is_opname_match, garbage_collect
from .save_op_args import save_op_args, serialize_args_to_dict
from .pretty_print import dict_data_list_to_table, packect_data_to_dict_list


class BackwardHookHandle:
    """One-shot backward hook for a captured op.

    Registers a hook on a result tensor's ``grad_fn`` that, when the
    backward pass reaches this op, saves the gradient inputs/outputs to
    disk and prints a summary table, then removes itself so it fires at
    most once.
    """

    def __init__(self, name, id) -> None:
        # name: qualified op name (e.g. "torch.Tensor.sort").
        # id: forward invocation identifier, e.g. "59/2024-10-09-11-40-14".
        self.name = name
        self.id = id

    def register_grad_fun_hook(self, tensor):
        """Attach the one-shot hook to ``tensor.grad_fn``.

        Returns the inner hook function (the registration handle is kept
        in a closure so the hook can remove itself after firing).
        """
        hook_handle = None

        def grad_fun(grad_inputs, grad_outputs):
            # Detach immediately: this hook must only run for the first
            # backward pass through this node.
            hook_handle.remove()

            save_op_args(self.name, f"{self.id}/grad_inputs", *tuple(grad_inputs))
            save_op_args(self.name, f"{self.id}/grad_outputs", *tuple(grad_outputs))

            grad_output_list = packect_data_to_dict_list(self.name + " grad_output", serialize_args_to_dict(*grad_outputs))
            grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(*grad_inputs))
            backward_args_table = dict_data_list_to_table(grad_output_list + grad_inputs_list)
            # Bug fix: was f"...{id}..." which interpolated the *builtin*
            # `id` function, printing "forward_id:<built-in function id>"
            # (visible in the README sample output). Use self.id instead.
            print(f"{self.name} forward_id:{self.id}\n{backward_args_table}", "\n" * 4)

        hook_handle = tensor.grad_fn.register_hook(grad_fun)

        return grad_fun


class OpCaptureHook(BaseHook):
    def __init__(self, name, func) -> None:
        # Delegate entirely to BaseHook: record the op's qualified name
        # and the original callable being wrapped; no extra state here.
        super().__init__(name, func)

def op_forward_args_to_table(self):
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs", serialize_args_to_dict(self.result))
forward_args_table = dict_data_list_to_table(inputs_list + output_list)
return forward_args_table

def before_call_op(self, *args, **kwargs):
self.forward_op_id = f"{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
with DisableHookGuard():
Expand All @@ -39,12 +57,19 @@ def after_call_op(self, result):
id = f"{self.forward_op_id}/output"
save_op_args(self.name, id, self.result)

table = self.op_forward_args_to_table()
print(f"{self.name} forward_id:{self.forward_op_id} {self.current_location} \n{table}", "\n"*4)

self.backward_hook_handle = BackwardHookHandle(self.name, self.forward_op_id)

for result in traverse_container(self.result):
if isinstance(result, torch.Tensor):
if result.grad_fn is not None:
result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_hook())
self.backward_hook_handle.register_grad_fun_hook(result)

id = self.id
self = None
garbage_collect(id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):
Expand Down
1 change: 0 additions & 1 deletion op_tools/op_dtype_cast_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def before_call_op(self, *args, **kwargs):
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
# print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
Expand Down
6 changes: 3 additions & 3 deletions op_tools/save_op_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
def serialize_args_to_dict(*args, **kwargs):
def tensor_to_dict(tensor):
return {
"device": str(tensor.device),
"dtype": str(tensor.dtype).split(".")[-1],
"numel": tensor.numel(),
"shape": str(tuple(tensor.shape)),
"stride": tensor.stride(),
"numel": tensor.numel(),
"dtype": str(tensor.dtype).split(".")[-1],
"device": str(tensor.device),
"requires_grad": tensor.requires_grad,
"layout": str(tensor.layout),
"data_ptr": tensor.data_ptr(),
Expand Down

0 comments on commit 949a690

Please sign in to comment.