Skip to content

Commit d562172

Browse files
VSJMilewskiwilliamFalcon
authored andcommitted
Allow for multiple example inputs when creating summary (#543)
1 parent b492e2b commit d562172

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

pytorch_lightning/core/memory.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,31 @@ def get_variable_sizes(self):
5050
input_ = self.model.example_input_array
5151

5252
if self.model.on_gpu:
53-
input_ = input_.cuda(0)
53+
device = next(self.model.parameters()).get_device()
54+
# test if input is a list or a tuple
55+
if isinstance(input_, (list, tuple)):
56+
input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i
57+
for input_i in input_]
58+
else:
59+
input_ = input_.cuda(device)
5460

5561
if self.model.trainer.use_amp:
56-
input_ = input_.half()
62+
# test if it is not a list or a tuple
63+
if isinstance(input_, (list, tuple)):
64+
input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
65+
for input_i in input_]
66+
else:
67+
input_ = input_.half()
5768

5869
with torch.no_grad():
5970

6071
for _, m in mods:
61-
if type(input_) is list or type(input_) is tuple: # pragma: no cover
72+
if isinstance(input_, (list, tuple)): # pragma: no cover
6273
out = m(*input_)
6374
else:
6475
out = m(input_)
6576

66-
if type(input_) is tuple or type(input_) is list: # pragma: no cover
77+
if isinstance(input_, (list, tuple)): # pragma: no cover
6778
in_size = []
6879
for x in input_:
6980
if type(x) is list:
@@ -75,7 +86,7 @@ def get_variable_sizes(self):
7586

7687
in_sizes.append(in_size)
7788

78-
if type(out) is tuple or type(out) is list: # pragma: no cover
89+
if isinstance(out, (list, tuple)): # pragma: no cover
7990
out_size = np.asarray([x.size() for x in out])
8091
else:
8192
out_size = np.array(out.size())

0 commit comments

Comments
 (0)