Skip to content

Commit 93f9fde

Browse files
committed
Add a pytorch tensor representation type
Signed-off-by: Jibin Varghese <[email protected]>
1 parent 3030fa9 commit 93f9fde

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from _pydevd_bundle.pydevd_extension_api import StrPresentationProvider
2+
from .pydevd_helpers import find_mod_attr, find_class_name
3+
4+
5+
class PyTorchTensorFormStr:
6+
def can_provide(self, type_object, type_name):
7+
torch_tensor_class = find_mod_attr('torch', 'Tensor')
8+
return torch_tensor_class is not None and issubclass(type_object, torch_tensor_class)
9+
10+
def get_str(self, val):
11+
if hasattr(val, 'shape') and hasattr(val, 'device'):
12+
return "%s [shape: %s, device: %s]: %r" % (find_class_name(val), str(list(val.shape)), val.device, val)
13+
else:
14+
return "%s: %r" % (find_class_name(val), val)
15+
16+
import sys
17+
18+
if not sys.platform.startswith("java"):
19+
StrPresentationProvider.register(PyTorchTensorFormStr)

0 commit comments

Comments
 (0)