Improve OpCapture and other minor changes #54

Merged · 1 commit · Oct 9, 2024
README.md (65 additions, 33 deletions)

@@ -53,48 +53,80 @@ with op_tools.OpCapture():
#### **Capture all forward and backward inputs and outputs**

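This example is shown as a diff: lines prefixed `-` are the old sample logs, lines prefixed `+` are the new output with argument tables. A minimal usage sketch that would produce such logs; the tensor shape and op sequence here are illustrative assumptions, only the `op_tools.OpCapture` context manager is confirmed by the hunk header:

```python
# Hedged usage sketch: wrap arbitrary torch code in op_tools.OpCapture to
# save every op's inputs/outputs as .pth files and print argument tables.
# Shapes and the op sequence are illustrative assumptions.
import torch
import op_tools

x = torch.randn(10, 20, requires_grad=True)
with op_tools.OpCapture():
    sorted_vals, indices = x.sort()
    sorted_vals.sum().backward()  # backward grads are captured via grad_fn hooks
```

Running code like this produces logs of the following form: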
```
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/output.pth saved
- apply OpCaptureHook on torch.Tensor.mul
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.mul/9/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.mul/9/output.pth saved
- ...
- apply OpCaptureHook on torch.Tensor.add
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.add/10/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.add/10/output.pth saved
- apply OpCaptureHook on torch.Tensor.sub
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sub/11/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sub/11/output.pth saved
- apply OpCaptureHook on torch.Tensor.div
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.div/12/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.div/12/output.pth saved
- apply OpCaptureHook on torch.Tensor.sort
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/output.pth saved
- apply OpCaptureHook on torch.Tensor.sum
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/output.pth saved
- skip OpCaptureHook on torch.Tensor.backward
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/grad_inputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sum/14/grad_outputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/grad_inputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/13/grad_outputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/grad_inputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.to/8/grad_outputs.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.add/283699/161/2024-10-09-11-42-15/input.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.add/283699/161/2024-10-09-11-42-15/output.pth saved
+ torch.Tensor.add forward_id:161/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+ +----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
+ | name                       | device | dtype   | numel | shape         | stride        | requires_grad | layout        | data_ptr       | value |
+ +----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
+ | torch.Tensor.add inputs[0] | npu:0  | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False         | torch.strided | 20067180408320 |       |
+ | torch.Tensor.add inputs[1] |        |         |       |               |               |               |               |                | 1e-05 |
+ | torch.Tensor.add outputs   | npu:0  | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False         | torch.strided | 20067180474368 |       |
+ +----------------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+-------+
+ apply OpCaptureHook on torch.rsqrt
+ op_tools_results/op_capture_results/torch.rsqrt/283699/162/2024-10-09-11-42-15/input.pth saved
+ op_tools_results/op_capture_results/torch.rsqrt/283699/162/2024-10-09-11-42-15/output.pth saved
+ torch.rsqrt forward_id:162/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+ +---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
+ | name                | device | dtype   | numel | shape         | stride        | requires_grad | layout        | data_ptr       |
+ +---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
+ | torch.rsqrt inputs  | npu:0  | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False         | torch.strided | 20067180474368 |
+ | torch.rsqrt outputs | npu:0  | float32 | 16384 | (1, 16384, 1) | (16384, 1, 1) | False         | torch.strided | 20067180540416 |
+ +---------------------+--------+---------+-------+---------------+---------------+---------------+---------------+----------------+
+ apply OpCaptureHook on torch.Tensor.mul
+ op_tools_results/op_capture_results/torch.Tensor.mul/283699/163/2024-10-09-11-42-15/input.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.mul/283699/163/2024-10-09-11-42-15/output.pth saved
+ torch.Tensor.mul forward_id:163/2024-10-09-11-42-15 /deeplink_afs/zhaoguochun/SmallModelOptimize/InternTrain/internlm/model/ops/norm.py:14 manual_rms_norm: my_input = my_input * torch.rsqrt(variance + eps)
+ +----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
+ | name                       | device | dtype    | numel    | shape            | stride              | requires_grad | layout        | data_ptr       |
+ +----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
+ | torch.Tensor.mul inputs[0] | npu:0  | bfloat16 | 33554432 | (1, 16384, 2048) | (33554432, 2048, 1) | True          | torch.strided | 20074677141504 |
+ | torch.Tensor.mul inputs[1] | npu:0  | float32  | 16384    | (1, 16384, 1)    | (16384, 1, 1)       | False         | torch.strided | 20067180540416 |
+ | torch.Tensor.mul outputs   | npu:0  | float32  | 33554432 | (1, 16384, 2048) | (33554432, 2048, 1) | False         | torch.strided | 20075012687360 |
+ +----------------------------+--------+----------+----------+------------------+---------------------+---------------+---------------+----------------+
+ ...
```
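The saved `.pth` files can be reloaded with `torch.load` for offline inspection. A sketch, with the path copied from the log above; how `save_op_args` packs the stored object is an assumption here:

```python
# Reload a captured input for offline inspection. The exact layout of the
# loaded object is determined by save_op_args and may differ from a plain
# tuple of tensors; print it to see.
import torch

path = "op_tools_results/op_capture_results/torch.Tensor.add/283699/161/2024-10-09-11-42-15/input.pth"
saved = torch.load(path, map_location="cpu")  # map to CPU for inspection off-device
print(type(saved))
print(saved)
```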

#### **Capture only the sort operator's arguments and ignore other operators: OP_CAPTURE_LIST=torch.Tensor.sort**
```
...
skip OpCaptureHook on torch.Tensor.mul
skip OpCaptureHook on torch.Tensor.add
skip OpCaptureHook on torch.Tensor.sub
skip OpCaptureHook on torch.Tensor.div
apply OpCaptureHook on torch.Tensor.sort
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/input.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/output.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/grad_inputs.pth saved
- op_capture_result/0/2024-08-06--11-41/torch.Tensor.sort/34/grad_outputs.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/input.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/output.pth saved
+ torch.Tensor.sort forward_id:59/2024-10-09-11-40-14 /deeplink_afs/zhaoguochun/ditorch2/op_tools/test/test_op_capture.py:15 f: sorted, indices = e.sort() # return torch.return_type.sort
+ +----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
+ | name                             | device | dtype   | numel | shape    | stride  | requires_grad | layout        | data_ptr       |
+ +----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
+ | torch.Tensor.sort inputs         | npu:0  | float32 | 200   | (10, 20) | (20, 1) | True          | torch.strided | 20067179830784 |
+ | torch.Tensor.sort outputs [0][0] | npu:0  | float32 | 200   | (10, 20) | (20, 1) | True          | torch.strided | 20067179831808 |
+ | torch.Tensor.sort outputs [0][1] | npu:0  | int64   | 200   | (10, 20) | (20, 1) | False         | torch.strided | 20067179832832 |
+ +----------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
+ skip OpCaptureHook on torch.Tensor.__getitem__
+ skip OpCaptureHook on torch.Tensor.sum
+ op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/grad_inputs.pth saved
+ op_tools_results/op_capture_results/torch.Tensor.sort/3834328/59/2024-10-09-11-40-14/grad_outputs.pth saved
+ torch.Tensor.sort forward_id:59/2024-10-09-11-40-14
+ +-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
+ | name                          | device | dtype   | numel | shape    | stride  | requires_grad | layout        | data_ptr       |
+ +-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
+ | torch.Tensor.sort grad_output | npu:0  | float32 | 200   | (10, 20) | (20, 1) | False         | torch.strided | 20067179835904 |
+ | torch.Tensor.sort grad_inputs | npu:0  | float32 | 200   | (10, 20) | (20, 1) | False         | torch.strided | 20067179836928 |
+ +-------------------------------+--------+---------+-------+----------+---------+---------------+---------------+----------------+
...
```
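Because the filter is read from the environment (the hook checks `os.getenv`, as in `op_capture_hook.py` below), `OP_CAPTURE_LIST` can be exported in the shell or set from Python before entering the context. A sketch mirroring the test referenced in the log above; the exact matching semantics live in `is_opname_match`:

```python
# Capture only torch.Tensor.sort; every other op logs "skip OpCaptureHook".
import os
os.environ["OP_CAPTURE_LIST"] = "torch.Tensor.sort"  # set before capturing

import torch
import op_tools

e = torch.randn(10, 20, requires_grad=True)
with op_tools.OpCapture():
    sorted_vals, indices = e.sort()   # captured: input/output/grad .pth files
    sorted_vals.sum().backward()      # sum itself is skipped by the filter
```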

op_tools/op_autocompare_hook.py (1 addition, 1 deletion)

@@ -286,7 +286,7 @@ def compare_backward_relate(self):

id = self.forward_op_id
self = None
- garbage_collect(id, 10)
+ garbage_collect(id, 2)

def save_forward_args(self):
save_op_args(
op_tools/op_capture_hook.py (29 additions, 4 deletions)

@@ -3,27 +3,45 @@
import torch
import time
from .base_hook import BaseHook, DisableHookGuard
- from .utils import traverse_container, is_opname_match
- from .save_op_args import save_op_args
+ from .utils import traverse_container, is_opname_match, garbage_collect
+ from .save_op_args import save_op_args, serialize_args_to_dict
+ from .pretty_print import dict_data_list_to_table, packect_data_to_dict_list


class BackwardHookHandle:
def __init__(self, name, id) -> None:
self.name = name
self.id = id

- def grad_fun_hook(self):
+ def register_grad_fun_hook(self, tensor):
hook_handle = None

def grad_fun(grad_inputs, grad_outputs):
hook_handle.remove()

save_op_args(self.name, f"{self.id}/grad_inputs", *tuple(grad_inputs))
save_op_args(self.name, f"{self.id}/grad_outputs", *tuple(grad_outputs))

grad_output_list = packect_data_to_dict_list(self.name + " grad_output", serialize_args_to_dict(*grad_outputs))
grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(*grad_inputs))
backward_args_table = dict_data_list_to_table(grad_output_list + grad_inputs_list)
print(f"{self.name} forward_id:{self.id}\n{backward_args_table}", "\n" * 4)

hook_handle = tensor.grad_fn.register_hook(grad_fun)

return grad_fun

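For context, this is the standard once-only `grad_fn` hook pattern: register on the output tensor's autograd node, then remove the handle inside the hook so it fires for a single backward pass. A standalone sketch in plain PyTorch; the names here are illustrative, not part of op_tools:

```python
# Once-only grad_fn hook, as used by BackwardHookHandle above.
import torch

def register_once_hook(name, tensor):
    handle = None

    def grad_hook(grad_inputs, grad_outputs):
        handle.remove()  # detach immediately so the hook fires only once
        shapes = [tuple(g.shape) for g in grad_outputs if g is not None]
        print(f"{name} grad_output shapes: {shapes}")

    if tensor.grad_fn is not None:
        handle = tensor.grad_fn.register_hook(grad_hook)

x = torch.randn(4, requires_grad=True)
y = x.mul(2.0)
register_once_hook("torch.Tensor.mul", y)
y.sum().backward()  # prints once; later backward passes would not re-fire the hook
```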

class OpCaptureHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def op_forward_args_to_table(self):
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs", serialize_args_to_dict(self.result))
forward_args_table = dict_data_list_to_table(inputs_list + output_list)
return forward_args_table

def before_call_op(self, *args, **kwargs):
self.forward_op_id = f"{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
with DisableHookGuard():
@@ -39,12 +57,19 @@ def after_call_op(self, result):
id = f"{self.forward_op_id}/output"
save_op_args(self.name, id, self.result)

table = self.op_forward_args_to_table()
print(f"{self.name} forward_id:{self.forward_op_id} {self.current_location} \n{table}", "\n"*4)

self.backward_hook_handle = BackwardHookHandle(self.name, self.forward_op_id)

for result in traverse_container(self.result):
if isinstance(result, torch.Tensor):
if result.grad_fn is not None:
- result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_hook())
+ self.backward_hook_handle.register_grad_fun_hook(result)

id = self.id
self = None
garbage_collect(id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):
op_tools/op_dtype_cast_hook.py (0 additions, 1 deletion)

@@ -55,7 +55,6 @@ def before_call_op(self, *args, **kwargs):
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
- # print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
op_tools/save_op_args.py (3 additions, 3 deletions)

@@ -7,11 +7,11 @@
def serialize_args_to_dict(*args, **kwargs):
def tensor_to_dict(tensor):
return {
"device": str(tensor.device),
"dtype": str(tensor.dtype).split(".")[-1],
"numel": tensor.numel(),
"shape": str(tuple(tensor.shape)),
"stride": tensor.stride(),
"numel": tensor.numel(),
"dtype": str(tensor.dtype).split(".")[-1],
"device": str(tensor.device),
"requires_grad": tensor.requires_grad,
"layout": str(tensor.layout),
"data_ptr": tensor.data_ptr(),
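The reordering above makes the serialized key order match the column order of the tables printed in the README examples. For reference, a self-contained sketch of the per-tensor serialization exactly as it appears in the new version of the diff; the surrounding `serialize_args_to_dict` plumbing for nested args is omitted:

```python
# Per-tensor serialization from the diff above; key order mirrors the
# README table columns (device, dtype, numel, shape, stride, ...).
import torch

def tensor_to_dict(tensor):
    return {
        "device": str(tensor.device),
        "dtype": str(tensor.dtype).split(".")[-1],  # "torch.float32" -> "float32"
        "numel": tensor.numel(),
        "shape": str(tuple(tensor.shape)),
        "stride": tensor.stride(),
        "requires_grad": tensor.requires_grad,
        "layout": str(tensor.layout),
        "data_ptr": tensor.data_ptr(),
    }

print(tensor_to_dict(torch.randn(10, 20)))
```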