@@ -50,20 +50,31 @@ def get_variable_sizes(self):
50
50
input_ = self .model .example_input_array
51
51
52
52
if self .model .on_gpu :
53
- input_ = input_ .cuda (0 )
53
+ device = next (self .model .parameters ()).get_device ()
54
+ # test if input is a list or a tuple
55
+ if isinstance (input_ , (list , tuple )):
56
+ input_ = [input_i .cuda (device ) if torch .is_tensor (input_i ) else input_i
57
+ for input_i in input_ ]
58
+ else :
59
+ input_ = input_ .cuda (device )
54
60
55
61
if self .model .trainer .use_amp :
56
- input_ = input_ .half ()
62
+ # test if it is not a list or a tuple
63
+ if isinstance (input_ , (list , tuple )):
64
+ input_ = [input_i .half () if torch .is_tensor (input_i ) else input_i
65
+ for input_i in input_ ]
66
+ else :
67
+ input_ = input_ .half ()
57
68
58
69
with torch .no_grad ():
59
70
60
71
for _ , m in mods :
61
- if type (input_ ) is list or type ( input_ ) is tuple : # pragma: no cover
72
+ if isinstance (input_ , ( list , tuple )) : # pragma: no cover
62
73
out = m (* input_ )
63
74
else :
64
75
out = m (input_ )
65
76
66
- if type (input_ ) is tuple or type ( input_ ) is list : # pragma: no cover
77
+ if isinstance (input_ , ( list , tuple )) : # pragma: no cover
67
78
in_size = []
68
79
for x in input_ :
69
80
if type (x ) is list :
@@ -75,7 +86,7 @@ def get_variable_sizes(self):
75
86
76
87
in_sizes .append (in_size )
77
88
78
- if type (out ) is tuple or type ( out ) is list : # pragma: no cover
89
+ if isinstance (out , ( list , tuple )) : # pragma: no cover
79
90
out_size = np .asarray ([x .size () for x in out ])
80
91
else :
81
92
out_size = np .array (out .size ())
0 commit comments