Skip to content

Commit

Permalink
refactor ditorch
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoguochun1995 committed Oct 30, 2024
1 parent ccc9ed8 commit bcb0a6f
Show file tree
Hide file tree
Showing 24 changed files with 248 additions and 218 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/runs_on_ascend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ jobs:
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${GITHUB_JOB} && cp -R Build ${GITHUB_JOB} && cd ${GITHUB_JOB}
export PYTHONPATH=${PYTHONPATH}:$PWD
echo "start to test"
bash ci/run_mock_ops_test.sh npu
bash ci/run_op_tools_test_cases.sh
bash ci/run_individual_test_cases.sh
Test_use_pytorch_test_case:
name: run pytorch test case on ascend
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/runs_on_camb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
echo "start to test"
source /mnt/cache/share/platform/env/ditorch_env
export PYTHONPATH=${PYTHONPATH}:$PWD
srun -p camb_mlu370_m8 -n 1 --gres=mlu:1 bash ci/run_op_tools_test_cases.sh
srun -p camb_mlu370_m8 -n 1 --gres=mlu:1 bash ci/run_individual_test_cases.sh
Test_use_pytorch_test_case:
name: run pytorch test case on camb
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/runs_on_nv.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
set -ex
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${GITHUB_JOB} && cp -R Build ${GITHUB_JOB} && cd ${GITHUB_JOB}
echo "start to test"
srun --job-name=${GITHUB_JOB} bash -c "source /mnt/cache/share/platform/env/ditorch_env && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${GITHUB_JOB} && export PYTHONPATH=${PYTHONPATH}:$PWD && bash ci/run_op_tools_test_cases.sh"
srun --job-name=${GITHUB_JOB} bash -c "source /mnt/cache/share/platform/env/ditorch_env && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${GITHUB_JOB} && export PYTHONPATH=${PYTHONPATH}:$PWD && bash ci/run_individual_test_cases.sh"
Test_use_pytorch_test_case:
name: run pytorch test case on nv
Expand Down
11 changes: 11 additions & 0 deletions ci/run_individual_test_cases.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Run every test*.py file below the current directory as a standalone script,
# record the combined output in test_individual_cases.log, and exit non-zero
# if any of them failed.
date

# - The pattern is quoted so the shell cannot expand it against files in the
#   current directory before find sees it.
# - Paths are NUL-delimited and handed to the inner shell as "$1" instead of
#   being spliced into the script text, so unusual file names are never
#   interpreted as shell code.
# - printf is used for the verdict lines because bash's builtin echo prints
#   "\n" literally unless -e is given.
find . -name 'test*.py' -print0 | xargs -0 -I {} bash -c 'echo "start run $1"; date; time python "$1" && printf "Test %s PASSED\n\n\n" "$1" || printf "Test %s FAILED\n\n\n" "$1"' _ {} 2>&1 | tee test_individual_cases.log

# Check if any tests failed
if grep -Eq "FAILED" test_individual_cases.log; then
    echo "tests failed"
    exit 1
else
    echo "all tests passed"
    exit 0
fi
21 changes: 0 additions & 21 deletions ci/run_mock_ops_test.sh

This file was deleted.

11 changes: 0 additions & 11 deletions ci/run_op_tools_test_cases.sh

This file was deleted.

35 changes: 16 additions & 19 deletions ditorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
# Copyright (c) 2024, DeepLink.
import os

framework = None
adapter = None
try:
from ditorch import torch_npu_adapter

framework = "torch_npu:" + torch_npu_adapter.torch_npu.__version__
except Exception as e: # noqa: F841
import ditorch.torch_npu_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass
try:
from ditorch import torch_dipu_adapter # noqa: F401

framework = "torch_dipu" # torch_dipu has not __version__ attr

except Exception as e: # noqa: F841
pass

try:
from ditorch import torch_mlu_adapter
import ditorch.torch_dipu_adapter as adapter # noqa: F811

framework = "torch_mlu:" + torch_mlu_adapter.torch_mlu.__version__
except Exception as e: # noqa: F841
except ImportError as e: # noqa: F841
pass

try:
from ditorch import torch_biren_adapter
import ditorch.torch_mlu_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass

framework = "torch_br:" + torch_biren_adapter.torch_br.__version__
except Exception as e: # noqa: F841
try:
import ditorch.torch_biren_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass


from ditorch import common_adapter # noqa: F401,E402

print(f"ditorch.framework: {framework} pid: {os.getpid()}")
if adapter is not None:
    # Not every adapter module defines mock() (torch_dipu_adapter only exposes
    # `framework` and `arch`), so look the hook up instead of assuming it.
    adapter_mock = getattr(adapter, "mock", None)
    if callable(adapter_mock):
        adapter_mock()
    common_adapter.mock_common()

    # Some frameworks (torch_dipu) have no __version__ attribute; fall back to
    # a placeholder rather than failing the whole import on a banner line.
    framework_version = getattr(adapter.framework, "__version__", "unknown")
    print(f"ditorch: {adapter.arch} {adapter.framework.__name__}:{framework_version} pid: {os.getpid()}")
else:
    # None of the vendor adapters could be imported; report that instead of
    # dereferencing `adapter`.
    print(f"ditorch: no vendor adapter available pid: {os.getpid()}")
32 changes: 0 additions & 32 deletions ditorch/common_adapter.py

This file was deleted.

5 changes: 5 additions & 0 deletions ditorch/common_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .common_mock import mock_tensor_device


def mock_common():
    """Apply the vendor-independent mocks; currently only the tensor-device mock."""
    mock_tensor_device()
37 changes: 37 additions & 0 deletions ditorch/common_adapter/common_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import os


# Device types that are always reported under their own name. Any other device
# type is presented as "cuda" while the mock is active.
_NATIVE_DEVICE_TYPES = ("cpu", "mps", "xpu", "xla", "meta")

# The DeviceMock instance that has been entered, or None if none is active.
_active_device_mock = None


def mock_tensor_device():
    """Make tensors on vendor devices report and print their device as "cuda".

    Installs a process-wide ``TorchFunctionMode`` that rewrites the result of
    ``Tensor.device`` and ``Tensor.__repr__`` for every device type that is not
    listed in ``_NATIVE_DEVICE_TYPES``.

    The mock is only installed when torch is at least 2.0.0 and the environment
    variable ``DITORCH_SHOW_DEVICE_AS_CUDA`` is unset or equal to ``"1"``.

    The function is idempotent: it is invoked when this module is imported and
    may be invoked again by callers, but the mode is entered at most once so
    that calls are not intercepted by several stacked copies of the same mode.
    """
    global _active_device_mock
    if _active_device_mock is not None:
        return
    if not torch.__version__ >= "2.0.0":
        return

    from torch.overrides import TorchFunctionMode, resolve_name

    class DeviceMock(TorchFunctionMode):
        def __torch_function__(self, func, types, args, kwargs=None):
            try:
                name = resolve_name(func)
            except Exception:
                # Name resolution is best effort; an unresolved function is
                # simply passed through unchanged.
                name = None
            result = func(*args, **(kwargs or {}))
            if name == "torch.Tensor.device.__get__":
                if result.type not in _NATIVE_DEVICE_TYPES:
                    device_str = "cuda"
                    if result.index is not None:
                        device_str += f":{result.index}"
                    result = torch.device(device_str)
            elif name == "torch.Tensor.__repr__":
                device = args[0].device
                # Use the same exemption list as the device getter so that a
                # tensor never reports one device type and prints another.
                if device.type not in _NATIVE_DEVICE_TYPES:
                    result = result.replace(device.type, "cuda")
            return result

    if os.getenv("DITORCH_SHOW_DEVICE_AS_CUDA", "1") == "1":
        device_mock = DeviceMock()
        # Entered without a matching exit on purpose: the mode is meant to
        # stay active from here on.
        device_mock.__enter__()
        _active_device_mock = device_mock


mock_tensor_device()
Empty file removed ditorch/mock_ops/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pytest
import torch
import ditorch # noqa: F401
import torch.distributed as dist
import torch.multiprocessing as mp
import unittest

world_size = torch.cuda.device_count()


# Initialize the distributed environment
Expand Down Expand Up @@ -115,29 +117,33 @@ def reduce_test(rank, world_size):
cleanup()


# pytest test cases

def test_all_reduce(world_size=2):
"""pytest wrapper for all_reduce test"""
run_distributed_test(all_reduce_test, world_size)

class TestDist(unittest.TestCase):

def test_reduce_scatter(world_size=2):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_test, world_size)
@unittest.skipIf(world_size < 2, "Communication test requires at least two cards")
def test_all_reduce(self, world_size=world_size):
"""pytest wrapper for all_reduce test"""
run_distributed_test(all_reduce_test, world_size)

@unittest.skipIf(world_size < 2, "Communication test requires at least two cards")
def test_reduce_scatter(self, world_size=world_size):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_test, world_size)

def test_reduce_scatter_tensor(world_size=2):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_tensor_test, world_size)
@unittest.skipIf(world_size < 2, "Communication test requires at least two cards")
def test_reduce_scatter_tensor(self, world_size=world_size):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_tensor_test, world_size)

@unittest.skipIf(world_size < 2, "Communication test requires at least two cards")
def test__reduce_scatter_base(self, world_size=world_size):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_base_test, world_size)

def test__reduce_scatter_base(world_size=2):
"""pytest wrapper for reduce_scatter test"""
run_distributed_test(reduce_scatter_base_test, world_size)
@unittest.skipIf(world_size < 2, "Communication test requires at least two cards")
def test_reduce(self, world_size=world_size):
"""pytest wrapper for reduce test"""
run_distributed_test(reduce_test, world_size)


@pytest.mark.parametrize("world_size", [2])
def test_reduce(world_size):
"""pytest wrapper for reduce test"""
run_distributed_test(reduce_test, world_size)
if __name__ == "__main__":
unittest.main()
30 changes: 15 additions & 15 deletions ditorch/test/summary_test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,23 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901
skipped_test_case = {}
failed_test_case = {}
for info in test_infos:
if info['file'] not in passed_test_case:
passed_test_case[info['file']] = []
if info['file'] not in skipped_test_case:
skipped_test_case[info['file']] = []
if info['file'] not in failed_test_case:
failed_test_case[info['file']] = []
if info["file"] not in passed_test_case:
passed_test_case[info["file"]] = []
if info["file"] not in skipped_test_case:
skipped_test_case[info["file"]] = []
if info["file"] not in failed_test_case:
failed_test_case[info["file"]] = []

case_name = info["classname"] + "." + info["name"]

if info["status"] == "passed":
if case_name not in passed_test_case[info['file']]:
if case_name not in passed_test_case[info["file"]]:
passed_test_case[info["file"]].append(case_name)
elif info["status"] == "skipped":
if case_name not in skipped_test_case[info['file']]:
if case_name not in skipped_test_case[info["file"]]:
skipped_test_case[info["file"]].append(case_name)
elif info["status"] == "error":
if case_name not in failed_test_case[info['file']]:
if case_name not in failed_test_case[info["file"]]:
failed_test_case[info["file"]].append(case_name)

passed_case_file_name = pytorch_test_result + "/passed_test_case.json"
Expand All @@ -113,9 +113,9 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901

for info in test_infos:
case_name = info["classname"] + "." + info["name"]
if info['file'] in all_test_case.keys():
if case_name in all_test_case[info['file']]:
all_test_case[info['file']].remove(case_name)
if info["file"] in all_test_case.keys():
if case_name in all_test_case[info["file"]]:
all_test_case[info["file"]].remove(case_name)
with open(never_device_tested_case_file_name, "w") as f:
f.write(json.dumps(all_test_case))

Expand All @@ -129,9 +129,9 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901

for info in test_infos:
case_name = info["classname"] + "." + info["name"]
if info['file'] in all_test_case.keys():
if case_name in all_test_case[info['file']]:
all_test_case[info['file']].remove(case_name)
if info["file"] in all_test_case.keys():
if case_name in all_test_case[info["file"]]:
all_test_case[info["file"]].remove(case_name)
with open(never_cpu_tested_case_file_name, "w") as f:
f.write(json.dumps(all_test_case))

Expand Down
1 change: 0 additions & 1 deletion ditorch/test/test_mock_npu/requirements.txt

This file was deleted.

3 changes: 0 additions & 3 deletions ditorch/torch_biren_adapter.py

This file was deleted.

10 changes: 10 additions & 0 deletions ditorch/torch_biren_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024, DeepLink.
import torch_br # noqa: F401


def mock():
    """Activate the Biren backend by importing torch_br's transfer_to_supa.

    The import is performed only for its side effect; nothing is returned.
    NOTE(review): presumably this redirects CUDA-style calls to SUPA devices;
    confirm against the torch_br documentation.
    """
    from torch_br.contrib import transfer_to_supa  # noqa: F401


# Vendor framework module; the ditorch package reads its __name__ and
# __version__ for the banner printed at import time.
framework = torch_br
# Short architecture label shown in the same banner.
arch = "biren"
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) 2024, DeepLink.
import torch # noqa: F401
import torch_dipu # noqa: F401

# Vendor framework module; the ditorch package reads its __name__ and
# __version__ for the banner printed at import time.
framework = torch_dipu
# Short architecture label shown in the same banner.
arch = "dipu"
34 changes: 0 additions & 34 deletions ditorch/torch_mlu_adapter.py

This file was deleted.

Loading

0 comments on commit bcb0a6f

Please sign in to comment.