diff --git a/ditorch/torch_npu_adapter.py b/ditorch/torch_npu_adapter.py index 986baa6..5b08d8f 100644 --- a/ditorch/torch_npu_adapter.py +++ b/ditorch/torch_npu_adapter.py @@ -1,3 +1,17 @@ # Copyright (c) 2024, DeepLink. +import torch import torch_npu # noqa: F401 from torch_npu.contrib import transfer_to_npu # noqa: F401 + + +def current_stream(device=None): + old_device = torch.cuda.current_device() + if device is None: + device = old_device + torch.cuda.set_device(device) + stream = torch_npu.npu.current_stream(device) + torch.cuda.set_device(old_device) + return stream + + +torch.cuda.current_stream = current_stream diff --git a/op_tools/test/test_current_stream.py b/op_tools/test/test_current_stream.py new file mode 100644 index 0000000..58a04b0 --- /dev/null +++ b/op_tools/test/test_current_stream.py @@ -0,0 +1,18 @@ +import ditorch +import torch +import unittest + + +class TestCurrentStream(unittest.TestCase): + def test_current_stream(self): + stream = torch.cuda.current_stream() + self.assertIsInstance(stream, torch.cuda.Stream) + + def test_current_stream_device(self): + for device in range(torch.cuda.device_count()): + stream = torch.cuda.current_stream(device) + self.assertIsInstance(stream, torch.cuda.Stream) + + +if __name__ == "__main__": + unittest.main()