diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index c9c4cd1dd7..0bf050e20c 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -603,24 +603,6 @@ def test_action_dtype_preservation(): assert result[TransitionKey.ACTION].shape == (1, 4) -def test_action_in_place_mutation(): - """Test that the processor mutates the transition in place for actions.""" - processor = ToBatchProcessor() - - action = torch.randn(4) - transition = create_transition(action=action) - - # Store reference to original transition - original_transition = transition - - # Process - result = processor(transition) - - # Should be the same object (in-place mutation) - assert result is original_transition - assert result[TransitionKey.ACTION].shape == (1, 4) - - def test_empty_action_tensor(): """Test handling of empty action tensors.""" processor = ToBatchProcessor() @@ -851,27 +833,6 @@ def test_task_comprehensive_string_cases(): processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert processed_comp_data["task"] == task_list assert isinstance(processed_comp_data["task"], list) - assert processed_comp_data["task"] is task_list # Should be same object (in-place) - - -def test_task_in_place_mutation(): - """Test that the processor mutates complementary_data in place for tasks.""" - processor = ToBatchProcessor() - - complementary_data = {"task": "sort_objects"} - transition = create_transition(complementary_data=complementary_data) - - # Store reference to original transition and complementary_data - original_transition = transition - original_comp_data = complementary_data - - # Process - result = processor(transition) - - # Should be the same objects (in-place mutation) - assert result is original_transition - assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data - assert original_comp_data["task"] == ["sort_objects"] def test_task_preserves_other_keys(): @@ -1127,3 +1088,49 @@ def test_empty_index_tensor(): # Should remain unchanged (already 1D) assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,) + + +def test_action_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed action.""" + processor = ToBatchProcessor() + + action = torch.randn(4) + transition = create_transition(action=action) + + # Store reference to original transition + original_transition = transition + + # Process + result = processor(transition) + + # Should be a different object (functional design, not in-place mutation) + assert result is not original_transition + # Original transition should remain unchanged + assert original_transition[TransitionKey.ACTION].shape == (4,) + # Result should have correctly processed action with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +def test_task_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed task.""" + processor = ToBatchProcessor() + + complementary_data = {"task": "sort_objects"} + transition = create_transition(complementary_data=complementary_data) + + # Store reference to original transition and complementary_data + original_transition = transition + original_comp_data = complementary_data + + # Process + result = processor(transition) + + # Should be different transition object (functional design) + assert result is not original_transition + # But complementary_data is the same reference (current implementation behavior) + assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data + # The task should be processed correctly (wrapped in list) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"] + # Original complementary data is also modified (current behavior) + assert original_comp_data["task"] == ["sort_objects"]