@@ -54,14 +54,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
5454 target = target .clone ()
5555
5656 if target .ndim == 1 :
57- target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = torch . float32 )
57+ target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = batch . dtype )
5858
5959 if torch .rand (1 ).item () >= self .p :
6060 return batch , target
6161
6262 # It's faster to roll the batch by one instead of shuffling it to create image pairs
6363 batch_rolled = batch .roll (1 , 0 )
64- target_rolled = target .roll (1 )
64+ target_rolled = target .roll (1 , 0 )
6565
6666 # Implemented as on mixup paper, page 3.
6767 lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .alpha , self .alpha ]))[0 ])
@@ -132,14 +132,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
132132 target = target .clone ()
133133
134134 if target .ndim == 1 :
135- target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = torch . float32 )
135+ target = torch .nn .functional .one_hot (target , num_classes = self .num_classes ).to (dtype = batch . dtype )
136136
137137 if torch .rand (1 ).item () >= self .p :
138138 return batch , target
139139
140140 # It's faster to roll the batch by one instead of shuffling it to create image pairs
141141 batch_rolled = batch .roll (1 , 0 )
142- target_rolled = target .roll (1 )
142+ target_rolled = target .roll (1 , 0 )
143143
144144 # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
145145 lambda_param = float (torch ._sample_dirichlet (torch .tensor ([self .alpha , self .alpha ]))[0 ])
0 commit comments