-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ccc9ed8
commit bcb0a6f
Showing
24 changed files
with
248 additions
and
218 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
date | ||
find . -name test*.py | xargs -I {} bash -c ' echo "start run {}";date;time python {} && echo "Test {} PASSED\n\n\n" || echo "Test {} FAILED\n\n\n"' 2>&1 | tee test_individual_cases.log | ||
|
||
# Check if any tests failed | ||
if grep -Eq "FAILED" test_individual_cases.log; then | ||
echo "tests failed" | ||
exit 1 | ||
else | ||
echo "all tests passed" | ||
exit 0 | ||
fi |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,33 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
import os | ||
|
||
framework = None | ||
adapter = None | ||
try: | ||
from ditorch import torch_npu_adapter | ||
|
||
framework = "torch_npu:" + torch_npu_adapter.torch_npu.__version__ | ||
except Exception as e: # noqa: F841 | ||
import ditorch.torch_npu_adapter as adapter # noqa: F811 | ||
except ImportError as e: # noqa: F841 | ||
pass | ||
try: | ||
from ditorch import torch_dipu_adapter # noqa: F401 | ||
|
||
framework = "torch_dipu" # torch_dipu has not __version__ attr | ||
|
||
except Exception as e: # noqa: F841 | ||
pass | ||
|
||
try: | ||
from ditorch import torch_mlu_adapter | ||
import ditorch.torch_dipu_adapter as adapter # noqa: F811 | ||
|
||
framework = "torch_mlu:" + torch_mlu_adapter.torch_mlu.__version__ | ||
except Exception as e: # noqa: F841 | ||
except ImportError as e: # noqa: F841 | ||
pass | ||
|
||
try: | ||
from ditorch import torch_biren_adapter | ||
import ditorch.torch_mlu_adapter as adapter # noqa: F811 | ||
except ImportError as e: # noqa: F841 | ||
pass | ||
|
||
framework = "torch_br:" + torch_biren_adapter.torch_br.__version__ | ||
except Exception as e: # noqa: F841 | ||
try: | ||
import ditorch.torch_biren_adapter as adapter # noqa: F811 | ||
except ImportError as e: # noqa: F841 | ||
pass | ||
|
||
|
||
from ditorch import common_adapter # noqa: F401,E402 | ||
|
||
print(f"ditorch.framework: {framework} pid: {os.getpid()}") | ||
if adapter is not None: | ||
adapter.mock() | ||
common_adapter.mock_common() | ||
|
||
print(f"ditorch: {adapter.arch} {adapter.framework.__name__}:{adapter.framework.__version__} pid: {os.getpid()}") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .common_mock import mock_tensor_device | ||
|
||
|
||
def mock_common(): | ||
mock_tensor_device() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
import os | ||
|
||
|
||
def mock_tensor_device(): | ||
if torch.__version__ >= "2.0.0": | ||
from torch.overrides import TorchFunctionMode, resolve_name | ||
|
||
class DeviceMock(TorchFunctionMode): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def __torch_function__(self, func, types, args, kwargs=None): | ||
try: | ||
name = resolve_name(func) | ||
except Exception: | ||
name = None | ||
result = func(*args, **(kwargs or {})) | ||
if name == "torch.Tensor.device.__get__": | ||
if result.type not in ["cpu", "mps", "xpu", "xla", "meta"]: | ||
device_str = "cuda" | ||
if result.index is not None: | ||
device_str += f":{result.index}" | ||
result = torch.device(device_str) | ||
if name == "torch.Tensor.__repr__": | ||
device = args[0].device | ||
if device.type != "cpu": | ||
result = result.replace(device.type, "cuda") | ||
|
||
return result | ||
|
||
device_mock = DeviceMock() | ||
if os.getenv("DITORCH_SHOW_DEVICE_AS_CUDA", "1") == "1": | ||
device_mock.__enter__() | ||
|
||
|
||
mock_tensor_device() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
import torch_br # noqa: F401 | ||
|
||
|
||
def mock(): | ||
from torch_br.contrib import transfer_to_supa # noqa: F401 | ||
|
||
|
||
framework = torch_br | ||
arch = "biren" |
3 changes: 3 additions & 0 deletions
3
ditorch/torch_dipu_adapter.py → ditorch/torch_dipu_adapter/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
import torch # noqa: F401 | ||
import torch_dipu # noqa: F401 | ||
|
||
framework = torch_dipu | ||
arch = "dipu" |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.