Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 46 additions & 39 deletions tests/processor/test_batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,6 @@ def test_action_dtype_preservation():
assert result[TransitionKey.ACTION].shape == (1, 4)


def test_action_in_place_mutation():
"""Test that the processor mutates the transition in place for actions."""
processor = ToBatchProcessor()

action = torch.randn(4)
transition = create_transition(action=action)

# Store reference to original transition
original_transition = transition

# Process
result = processor(transition)

# Should be the same object (in-place mutation)
assert result is original_transition
assert result[TransitionKey.ACTION].shape == (1, 4)


def test_empty_action_tensor():
"""Test handling of empty action tensors."""
processor = ToBatchProcessor()
Expand Down Expand Up @@ -851,27 +833,6 @@ def test_task_comprehensive_string_cases():
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == task_list
assert isinstance(processed_comp_data["task"], list)
assert processed_comp_data["task"] is task_list # Should be same object (in-place)


def test_task_in_place_mutation():
"""Test that the processor mutates complementary_data in place for tasks."""
processor = ToBatchProcessor()

complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)

# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data

# Process
result = processor(transition)

# Should be the same objects (in-place mutation)
assert result is original_transition
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
assert original_comp_data["task"] == ["sort_objects"]


def test_task_preserves_other_keys():
Expand Down Expand Up @@ -1127,3 +1088,49 @@ def test_empty_index_tensor():

# Should remain unchanged (already 1D)
assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,)


def test_action_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed action."""
processor = ToBatchProcessor()

action = torch.randn(4)
transition = create_transition(action=action)

# Store reference to original transition
original_transition = transition

# Process
result = processor(transition)

# Should be a different object (functional design, not in-place mutation)
assert result is not original_transition
# Original transition should remain unchanged
assert original_transition[TransitionKey.ACTION].shape == (4,)
# Result should have correctly processed action with batch dimension
assert result[TransitionKey.ACTION].shape == (1, 4)
assert torch.equal(result[TransitionKey.ACTION][0], action)


def test_task_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed task."""
processor = ToBatchProcessor()

complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)

# Store reference to original transition and complementary_data
original_transition = transition
original_comp_data = complementary_data

# Process
result = processor(transition)

# Should be different transition object (functional design)
assert result is not original_transition
# But complementary_data is the same reference (current implementation behavior)
assert result[TransitionKey.COMPLEMENTARY_DATA] is original_comp_data
# The task should be processed correctly (wrapped in list)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"]
# Original complementary data is also modified (current behavior)
assert original_comp_data["task"] == ["sort_objects"]