Skip to content

Commit

Permalink
add ditorch and op_tool config file (#71)
Browse files Browse the repository at this point in the history
* add ditorch and op_tool config file

* minor change

* Ignore framework custom operators in the recommended configuration

* Set the dtype used by the CPU for autocompare for torch.exp and torch.Tensor.std

* minor change
  • Loading branch information
zhaoguochun1995 authored Oct 25, 2024
1 parent 4e1494f commit b490ae0
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 2 deletions.
5 changes: 4 additions & 1 deletion ditorch/common_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def __init__(self):
super().__init__()

def __torch_function__(self, func, types, args, kwargs=None):
name = resolve_name(func)
try:
name = resolve_name(func)
except Exception:
name = None
result = func(*args, **(kwargs or {}))
if name == "torch.Tensor.device.__get__":
if result.type != "cpu":
Expand Down
59 changes: 59 additions & 0 deletions ditorch_config.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# ditorch / op_tools configuration file (meant to be sourced).
# Every setting follows the same pattern:
#   export NAME=${NAME:-default_value}
# i.e. a value already present in the environment takes precedence;
# the value on the right is only a fallback default.
# export env_name=${env_name:-env_default_value}

# Present the backend device as "cuda" to callers (1 = enabled).
export DITORCH_SHOW_DEVICE_AS_CUDA=${DITORCH_SHOW_DEVICE_AS_CUDA:-1}


# Print a call stack for each observed op (0 = disabled).
export OP_TOOLS_PRINT_STACK=${OP_TOOLS_PRINT_STACK:-0}

# Upper bound on the op_tools internal cache size.
export OP_TOOLS_MAX_CACHE_SIZE=${OP_TOOLS_MAX_CACHE_SIZE:-1000}


# autocompare tool op selection.
# Convention for every *_LIST / *_DISABLE_LIST pair below: the LIST is a
# comma-separated set of op-name regexes to enable (".*" = all ops), and
# the DISABLE_LIST removes ops from that selection.
# Random ops (torch.rand/randn) and framework custom operators
# (torch_mlu.*, torch_npu.*) are disabled by default — presumably because
# they have no meaningful CPU reference to compare against; TODO confirm.
export OP_AUTOCOMPARE_DISABLE_LIST=${OP_AUTOCOMPARE_DISABLE_LIST:-"torch.rand,torch.randn,torch_mlu.*,torch_npu.*"}
export OP_AUTOCOMPARE_LIST=${OP_AUTOCOMPARE_LIST:-".*"}


# op_capture tool op selection (same LIST/DISABLE_LIST convention).
export OP_CAPTURE_DISABLE_LIST=${OP_CAPTURE_DISABLE_LIST:-""}
export OP_CAPTURE_LIST=${OP_CAPTURE_LIST:-".*"}

# op_dtype_cast tool op selection (same LIST/DISABLE_LIST convention).
export OP_DTYPE_CAST_DISABLE_LIST=${OP_DTYPE_CAST_DISABLE_LIST:-""}
export OP_DTYPE_CAST_LIST=${OP_DTYPE_CAST_LIST:-".*"}

# op_fallback tool op selection (same LIST/DISABLE_LIST convention).
export OP_FALLBACK_DISABLE_LIST=${OP_FALLBACK_DISABLE_LIST:-""}
export OP_FALLBACK_LIST=${OP_FALLBACK_LIST:-".*"}

# op_observe tool op selection (same LIST/DISABLE_LIST convention).
export OP_OBSERVE_DISABLE_LIST=${OP_OBSERVE_DISABLE_LIST:-""}
export OP_OBSERVE_LIST=${OP_OBSERVE_LIST:-".*"}


# overflow-check tool op selection (same LIST/DISABLE_LIST convention).
export OP_OVERFLOW_CHECK_DISABLE_LIST=${OP_OVERFLOW_CHECK_DISABLE_LIST:-""}
export OP_OVERFLOW_CHECK_LIST=${OP_OVERFLOW_CHECK_LIST:-".*"}


# time-measure tool op selection (same LIST/DISABLE_LIST convention).
export OP_TIME_MEASURE_DISABLE_LIST=${OP_TIME_MEASURE_DISABLE_LIST:-""}
export OP_TIME_MEASURE_LIST=${OP_TIME_MEASURE_LIST:-".*"}


# dtype-cast mappings, used by both the autocompare and op_dtype_cast tools.
# Format: comma-separated "src_dtype->dst_dtype" pairs.
# - For autocompare: the dtype the CPU reference computation is run in.
# - For op_dtype_cast: the dtype the DEVICE computation is cast to.
# Per-op overrides for numerically sensitive ops (cast up to float64):
export LINEAR_OP_DTYPE_CAST_DICT=${LINEAR_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export EMBEDDING_OP_DTYPE_CAST_DICT=${EMBEDDING_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export NORMALIZE_OP_DTYPE_CAST_DICT=${NORMALIZE_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export NORM_OP_DTYPE_CAST_DICT=${NORM_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export CROSS_ENTROPY_OP_DTYPE_CAST_DICT=${CROSS_ENTROPY_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export MUL_OP_DTYPE_CAST_DICT=${MUL_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export MATMUL_OP_DTYPE_CAST_DICT=${MATMUL_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export STD_OP_DTYPE_CAST_DICT=${STD_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
export EXP_OP_DTYPE_CAST_DICT=${EXP_OP_DTYPE_CAST_DICT:-"torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"}
# Fallback mapping for all other ops (cast half precision up to float32 only).
export OP_DTYPE_CAST_DICT=${OP_DTYPE_CAST_DICT:-"torch.float16->torch.float32,torch.bfloat16->torch.float32"}


# autocompare error tolerances, per dtype.
# Format appears to be "atol,rtol" (two comma-separated values) — TODO confirm
# against op_tools' tolerance parsing.
export AUTOCOMPARE_ERROR_TOLERANCE=${AUTOCOMPARE_ERROR_TOLERANCE:-"1e-3,1e-3"}
export AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16=${AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16:-"1e-4,1e-4"}
export AUTOCOMPARE_ERROR_TOLERANCE_BFLOAT16=${AUTOCOMPARE_ERROR_TOLERANCE_BFLOAT16:-"1e-3,1e-3"}
export AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32=${AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32:-"1e-5,1e-5"}
export AUTOCOMPARE_ERROR_TOLERANCE_FLOAT64=${AUTOCOMPARE_ERROR_TOLERANCE_FLOAT64:-"1e-8,1e-8"}

# Per-op tolerance override: relax the bfloat16 tolerance for linear ops.
export LINEAR_AUTOCOMPARE_ERROR_TOLERANCE_BFLOAT16=${LINEAR_AUTOCOMPARE_ERROR_TOLERANCE_BFLOAT16:-"1e-2,1e-2"}
2 changes: 1 addition & 1 deletion op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def count_params_with_requires_grad(self):

def compare_input_grad(self):
self.args_grad = self.grad_inputs_cpu
compare_info = compare_result(self.name + " grad", self.args_grad, self.args_cpu_grad)
compare_info = compare_result(self.name + " grad", list(self.args_grad), list(self.args_cpu_grad))
compare_info["forward_id"] = self.forward_op_id

compare_result_cache.append(self.forward_op_id, compare_info)
Expand Down
1 change: 1 addition & 0 deletions op_tools/test/test_opname_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_opname_match(self):
self.assertEqual(is_opname_match("torch.subc", "torch.add,.*"), True)
self.assertEqual(is_opname_match("torch.subc", None), True)
self.assertEqual(is_opname_match("torch.subc", ""), False)
self.assertEqual(is_opname_match("torch_mlu.fused_l2_norm", "torch_mlu.*"), True)
self.assertEqual(is_opname_match(None, "torch.add"), False)
self.assertEqual(is_opname_match(None, None), False)

Expand Down

0 comments on commit b490ae0

Please sign in to comment.