diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 3cac6d6f8a4..300191d8635 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -919,32 +919,6 @@ def test_device_detection_from_action(): assert attention_mask.device.type == "cuda" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") -def test_device_detection_from_complementary_data(): - """Test that device is detected from tensors in complementary_data.""" - mock_tokenizer = MockTokenizer(vocab_size=100) - processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10) - - # Create transition with tensor in complementary_data - transition = create_transition( - observation={"metadata": {"key": "value"}}, # No tensors - complementary_data={ - "task": "comp data test", - "index": torch.tensor([42]).cuda(), # Tensor in complementary_data - }, - ) - - result = processor(transition) - - # Check that tokenized tensors match complementary_data tensor's device - tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] - attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] - - assert tokens.device.type == "cuda" - assert attention_mask.device.type == "cuda" - - @require_package("transformers") def test_device_detection_preserves_dtype(): """Test that device detection doesn't affect dtype of tokenized tensors."""