Zgc/ditorch support overflow apply to third party interface #73

Merged


@zhaoguochun1995 (Collaborator) commented Oct 25, 2024

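  1. The precision overflow detection tool can now be applied to non-torch interfaces.
  2. Improve the README.
  3. Add configuration management (the relevant options are printed in the terminal, so users can see more clearly which options are configurable).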
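As a rough illustration of item 3, the sketch below shows one way an option could be read from the environment and echoed once in the `option: NAME=value` form seen in the session log. It is not taken from op_tools: the helpers `get_option` and `is_op_enabled` are hypothetical, the defaults are copied from the log output, and treating each comma-separated entry as a full-match regular expression is an assumption.

```python
import os
import re

# Defaults copied from the "option:" lines in the session log.
# The helpers below are hypothetical and are NOT the op_tools implementation.
DEFAULTS = {
    "OP_AUTOCOMPARE_LIST": ".*",
    "OP_AUTOCOMPARE_DISABLE_LIST": "torch.rand,torch.randn,torch_mlu.*,torch_npu.*",
    "OP_TOOLS_PRINT_STACK": "0",
    "OP_TOOLS_MAX_CACHE_SIZE": "1000",
}

_reported = set()  # option names that have already been printed


def get_option(name, environ=None):
    """Return the option value, printing it the first time it is read."""
    environ = os.environ if environ is None else environ
    value = environ.get(name, DEFAULTS[name])
    if name not in _reported:
        _reported.add(name)
        print(f"option: {name}={value}")
    return value


def is_op_enabled(op_name, environ=None):
    """Assumed rule: enabled if it matches the allow list and not the disable list."""
    allow = get_option("OP_AUTOCOMPARE_LIST", environ).split(",")
    deny = get_option("OP_AUTOCOMPARE_DISABLE_LIST", environ).split(",")

    def matches(patterns):
        return any(re.fullmatch(p, op_name) for p in patterns)

    return matches(allow) and not matches(deny)
```

Under these assumptions, `is_op_enabled("torch.randn", {})` is false and `is_op_enabled("torch.Tensor.add", {})` is true, which is consistent with the "skip" and "apply" lines in the log.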
>>> import op_tools
>>>        
>>> autocompare = op_tools.OpAutoCompare()
>>> autocompare.start()
>>> import torch
>>> x = torch.randn(3,4,device="cuda")
skip OpAutoCompareHook on torch.randn
>>> 
>>> y = x + x
option: OP_AUTOCOMPARE_DISABLE_LIST=torch.rand,torch.randn,torch_mlu.*,torch_npu.*
option: OP_AUTOCOMPARE_LIST=.*
option: OP_TOOLS_PRINT_STACK=0
apply OpAutoCompareHook on torch.Tensor.add
option: OP_DTYPE_CAST_DICT=torch.float16->torch.float32,torch.bfloat16->torch.float32
option: AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32=1e-5,1e-5
option: OP_TOOLS_MAX_CACHE_SIZE=1000

autocompare    torch.Tensor.add forward_id: 1    
<stdin>:1 <module>: 
+---------------------------------+--------+---------+-------+--------+--------+---------------+---------------+----------------+
|               name              | device |  dtype  | numel | shape  | stride | requires_grad |     layout    |    data_ptr    |
+---------------------------------+--------+---------+-------+--------+--------+---------------+---------------+----------------+
|    torch.Tensor.add inputs[0]   | cuda:0 | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided | 20067179823104 |
|    torch.Tensor.add inputs[1]   | cuda:0 | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided | 20067179823104 |
|     torch.Tensor.add outputs    | cuda:0 | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided | 20067179823616 |
| torch.Tensor.add inputs(cpu)[0] |  cpu   | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided |   546780864    |
| torch.Tensor.add inputs(cpu)[1] |  cpu   | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided |   546827200    |
|  torch.Tensor.add outputs(cpu)  |  cpu   | float32 |   12  | (3, 4) | (4, 1) |     False     | torch.strided |   556165184    |
+---------------------------------+--------+---------+-------+--------+--------+---------------+---------------+----------------+
+--------------------------------+----------+-------------------+--------------+-------------------+-------------+-------------+------------+
|              name              | allclose | cosine_similarity | max_abs_diff | max_relative_diff |     atol    |     rtol    | error_info |
+--------------------------------+----------+-------------------+--------------+-------------------+-------------+-------------+------------+
| torch.Tensor.add input[0]      |   True   |    1.000000000    | 0.000000000  |    0.000000000    | 0.000010000 | 0.000010000 |            |
| torch.Tensor.add input[1]      |   True   |    1.000000000    | 0.000000000  |    0.000000000    | 0.000010000 | 0.000010000 |            |
| torch.Tensor.add output        |   True   |    1.000000000    | 0.000000000  |    0.000000000    | 0.000010000 | 0.000010000 |            |
+--------------------------------+----------+-------------------+--------------+-------------------+-------------+-------------+------------+



>>> 
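For readers unfamiliar with the columns in the second table, here is a minimal pure-Python sketch of how metrics of this kind are commonly computed. It is an assumption about the general technique, not the op_tools code: the function name `compare` is hypothetical, the `allclose` check follows the `|actual - expected| <= atol + rtol * |expected|` rule used by `torch.allclose`, and the relative-difference definition (with a small epsilon in the denominator) is one of several possible choices.

```python
import math


def compare(expected, actual, atol=1e-5, rtol=1e-5, eps=1e-12):
    """Compare two equal-length, non-empty sequences of floats.

    Hypothetical helper for illustration; op_tools may define these
    metrics differently.
    """
    abs_diffs = [abs(a - e) for e, a in zip(expected, actual)]
    # Relative difference against |expected|; eps avoids division by zero.
    rel_diffs = [d / (abs(e) + eps) for d, e in zip(abs_diffs, expected)]

    # Element-wise closeness, same rule as torch.allclose.
    allclose = all(d <= atol + rtol * abs(e) for d, e in zip(abs_diffs, expected))

    dot = sum(e * a for e, a in zip(expected, actual))
    norms = math.sqrt(sum(e * e for e in expected)) * math.sqrt(sum(a * a for a in actual))
    if norms == 0.0:
        # Assumed convention: two all-zero vectors count as identical.
        cosine = 1.0 if not any(expected) and not any(actual) else 0.0
    else:
        cosine = dot / norms

    return {
        "allclose": allclose,
        "cosine_similarity": cosine,
        "max_abs_diff": max(abs_diffs),
        "max_relative_diff": max(rel_diffs),
    }
```

In the session above both tolerances are `1e-5`, matching the `AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32=1e-5,1e-5` option, and identical device and CPU results give zero differences with `allclose` true.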

@zhaoguochun1995 zhaoguochun1995 force-pushed the zgc/ditorch_support_overflow_apply_to_third_party_interface branch from 727948c to 59384ff Compare October 28, 2024 06:18
@yangbofun yangbofun merged commit fb14343 into main Oct 28, 2024
13 checks passed
@yangbofun yangbofun deleted the zgc/ditorch_support_overflow_apply_to_third_party_interface branch October 28, 2024 07:34