@@ -100,9 +100,7 @@ def forward(self, tensor):
100
100
sin_inp_y = torch .einsum ("i,j->ij" , pos_y , self .inv_freq )
101
101
emb_x = get_emb (sin_inp_x ).unsqueeze (1 )
102
102
emb_y = get_emb (sin_inp_y )
103
- emb = torch .zeros ((x , y , self .channels * 2 ), device = tensor .device ).type (
104
- tensor .type ()
105
- )
103
+ emb = torch .zeros ((x , y , self .channels * 2 ), device = tensor .device ).type (tensor .type ())
106
104
emb [:, :, : self .channels ] = emb_x
107
105
emb [:, :, self .channels : 2 * self .channels ] = emb_y
108
106
@@ -165,9 +163,7 @@ def forward(self, tensor):
165
163
emb_x = get_emb (sin_inp_x ).unsqueeze (1 ).unsqueeze (1 )
166
164
emb_y = get_emb (sin_inp_y ).unsqueeze (1 )
167
165
emb_z = get_emb (sin_inp_z )
168
- emb = torch .zeros ((x , y , z , self .channels * 3 ), device = tensor .device ).type (
169
- tensor .type ()
170
- )
166
+ emb = torch .zeros ((x , y , z , self .channels * 3 ), device = tensor .device ).type (tensor .type ())
171
167
emb [:, :, :, : self .channels ] = emb_x
172
168
emb [:, :, :, self .channels : 2 * self .channels ] = emb_y
173
169
emb [:, :, :, 2 * self .channels :] = emb_z
0 commit comments