Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/lerobot/utils/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from contextlib import suppress
from queue import Empty
from typing import Any

Expand All @@ -30,10 +32,18 @@ def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) ->
item = None

# Drain queue and keep only the most recent parameters
try:
while True:
if platform.system() == "Darwin":
# On Mac, avoid using `qsize` due to unreliable implementation.
# There is a comment on `qsize` code in the Python source:
# Raises NotImplementedError on Mac OSX because of broken sem_getvalue()
with suppress(Empty):
item = queue.get_nowait()

return item

# Details about using qsize in https://github.com/huggingface/lerobot/issues/1523
while queue.qsize() > 0:
with suppress(Empty):
item = queue.get_nowait()
except Empty:
pass

return item
48 changes: 30 additions & 18 deletions tests/utils/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import time
from queue import Queue

import pytest
from torch.multiprocessing import Queue as TorchMPQueue

from lerobot.utils.queue import get_last_item_from_queue


def test_get_last_item_single_item():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_single_item(queue_cls):
"""Test getting the last item when queue has only one item."""
queue = Queue()
queue = queue_cls()
queue.put("single_item")

result = get_last_item_from_queue(queue)
Expand All @@ -32,9 +36,10 @@ def test_get_last_item_single_item():
assert queue.empty()


def test_get_last_item_multiple_items():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_multiple_items(queue_cls):
"""Test getting the last item when queue has multiple items."""
queue = Queue()
queue = queue_cls()
items = ["first", "second", "third", "fourth", "last"]

for item in items:
Expand All @@ -46,9 +51,10 @@ def test_get_last_item_multiple_items():
assert queue.empty()


def test_get_last_item_different_types():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_different_types(queue_cls):
"""Test with different data types in the queue."""
queue = Queue()
queue = queue_cls()
items = [1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")]

for item in items:
Expand All @@ -60,9 +66,10 @@ def test_get_last_item_different_types():
assert queue.empty()


def test_get_last_item_maxsize_queue():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_maxsize_queue(queue_cls):
"""Test with a queue that has a maximum size."""
queue = Queue(maxsize=5)
queue = queue_cls(maxsize=5)

# Fill the queue
for i in range(5):
Expand All @@ -77,9 +84,10 @@ def test_get_last_item_maxsize_queue():
assert queue.empty()


def test_get_last_item_with_none_values():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_with_none_values(queue_cls):
"""Test with None values in the queue."""
queue = Queue()
queue = queue_cls()
items = [1, None, 2, None, 3]

for item in items:
Expand All @@ -94,23 +102,26 @@ def test_get_last_item_with_none_values():
assert queue.empty()


def test_get_last_item_blocking_timeout():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_blocking_timeout(queue_cls):
"""Test get_last_item_from_queue returns None on timeout."""
queue = Queue()
queue = queue_cls()
result = get_last_item_from_queue(queue, block=True, timeout=0.1)
assert result is None


def test_get_last_item_non_blocking_empty():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_non_blocking_empty(queue_cls):
"""Test get_last_item_from_queue with block=False on an empty queue returns None."""
queue = Queue()
queue = queue_cls()
result = get_last_item_from_queue(queue, block=False)
assert result is None


def test_get_last_item_non_blocking_success():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_non_blocking_success(queue_cls):
"""Test get_last_item_from_queue with block=False on a non-empty queue."""
queue = Queue()
queue = queue_cls()
items = ["first", "second", "last"]
for item in items:
queue.put(item)
Expand All @@ -123,9 +134,10 @@ def test_get_last_item_non_blocking_success():
assert queue.empty()


def test_get_last_item_blocking_waits_for_item():
@pytest.mark.parametrize("queue_cls", [Queue, TorchMPQueue])
def test_get_last_item_blocking_waits_for_item(queue_cls):
"""Test that get_last_item_from_queue waits for an item if block=True."""
queue = Queue()
queue = queue_cls()
result = []

def producer():
Expand Down
Loading