diff --git a/tests/influence/_utils/common.py b/tests/influence/_utils/common.py index b8cd10ba30..5d549638c9 100644 --- a/tests/influence/_utils/common.py +++ b/tests/influence/_utils/common.py @@ -285,21 +285,17 @@ def get_random_model_and_data( torch.normal(0, 1, (num_samples, in_features)).double() for _ in range(num_inputs) ] - all_samples = ( - _move_sample_to_cuda(all_samples) - if isinstance(all_samples, list) and use_gpu - else (all_samples.cuda() if use_gpu else all_samples) - ) + if use_gpu: + all_samples = _move_sample_to_cuda(all_samples) + train_samples = [ts[:num_train] for ts in all_samples] test_samples = [ts[num_train:] for ts in all_samples] hessian_samples = [ts[:num_hessian] for ts in all_samples] else: all_samples = torch.normal(0, 1, (num_samples, in_features)).double() - all_samples = ( - _move_sample_to_cuda(all_samples) - if isinstance(all_samples, list) and use_gpu - else (all_samples.cuda() if use_gpu else all_samples) - ) + + if use_gpu: + all_samples = all_samples.cuda() train_samples = all_samples[:num_train] test_samples = all_samples[num_train:] hessian_samples = all_samples[:num_hessian]