Skip to content

Commit

Permalink
Refactor RedisLock (PP-1481) (#2285)
Browse files Browse the repository at this point in the history
* Update redis lock

* Refactor locks

* Update task lock

* Code review feedback
  • Loading branch information
jonathangreen authored Feb 18, 2025
1 parent c3da14a commit 1a190eb
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 61 deletions.
12 changes: 2 additions & 10 deletions src/palace/manager/celery/tasks/marc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,9 @@ def marc_export_collection(
context or {}
)

lock = marc_export_collection_lock(
with marc_export_collection_lock(
task.services.redis.client(), collection_id, delta
)

with lock.lock() as locked:
if not locked:
task.log.info(
f"Skipping collection {collection_id} because another task is already processing it."
)
return

).lock():
with ExitStack() as stack, task.transaction() as session:
files = {
library: stack.enter_context(TemporaryFile())
Expand Down
9 changes: 2 additions & 7 deletions src/palace/manager/celery/tasks/opds_odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,9 @@ def recalculate_hold_queue_collection(
"""
Recalculate the hold queue for a collection.
"""
lock = _redis_lock_recalculate_holds(task.services.redis.client(), collection_id)
analytics = task.services.analytics.analytics()
with lock.lock() as locked:
if not locked:
task.log.info(
f"Skipping collection {collection_id} because another task holds its lock."
)
return
redis_client = task.services.redis.client()
with _redis_lock_recalculate_holds(redis_client, collection_id).lock():
with task.transaction() as session:
collection = Collection.by_id(session, collection_id)
if collection is None:
Expand Down
16 changes: 3 additions & 13 deletions src/palace/manager/celery/tasks/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,9 @@ def search_reindex(task: Task, offset: int = 0, batch_size: int = 500) -> None:
task will do a batch, then requeue itself until all works have been indexed.
"""
index = task.services.search.index()
redis_client = task.services.redis.client()
task_lock = TaskLock(redis_client, task, lock_name="search_reindex")

with task_lock.lock(
release_on_exit=False, ignored_exceptions=(Retry, Ignore)
) as acquired:
if not acquired:
raise BasePalaceException("Another re-index task is already running.")
task_lock = TaskLock(task, lock_name="search_reindex")

with task_lock.lock(release_on_exit=False, ignored_exceptions=(Retry, Ignore)):
task.log.info(
f"Running search reindex at offset {offset} with batch size {batch_size}."
)
Expand Down Expand Up @@ -134,11 +128,7 @@ def update_read_pointer(task: Task) -> None:
@shared_task(queue=QueueNames.default, bind=True)
def search_indexing(task: Task, batch_size: int = 500) -> None:
redis_client = task.services.redis.client()
task_lock = TaskLock(redis_client, task)
with task_lock.lock() as acquired:
if not acquired:
raise BasePalaceException(f"{task.name} is already running.")

with TaskLock(task).lock():
waiting = WaitingForIndexing(redis_client)
works = waiting.pop(batch_size)

Expand Down
21 changes: 18 additions & 3 deletions src/palace/manager/service/redis/models/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ class LockError(BasePalaceException):
pass


class LockValueError(LockError, ValueError):
pass


class LockNotAcquired(LockError):
pass


class BaseRedisLock(ABC):
def __init__(
self,
Expand Down Expand Up @@ -72,20 +80,25 @@ def key(self) -> str:
@contextmanager
def lock(
self,
raise_when_not_acquired: bool = True,
release_on_error: bool = True,
release_on_exit: bool = True,
ignored_exceptions: tuple[type[BaseException], ...] = (),
) -> Generator[bool, None, None]:
"""
Context manager for acquiring and releasing the lock.
:param raise_when_not_acquired: If True, raise an exception if the lock is not acquired.
:param release_on_error: If True, release the lock if an exception occurs.
:param release_on_exit: If True, release the lock when the context manager exits.
:param ignored_exceptions: Exceptions that should not cause the lock to be released.
:return: The result of the lock acquisition. You must check the return value to see if the lock was acquired.
"""
locked = self.acquire()
if raise_when_not_acquired and not locked:
raise LockNotAcquired(f"Lock {self.key} could not be acquired")

exception_occurred = False
try:
yield locked
Expand Down Expand Up @@ -180,7 +193,7 @@ def acquire_blocking(self, timeout: float | int = -1) -> bool:
:return: The result of the lock acquisition. You must check the return value to see if the lock was acquired.
"""
if timeout < 0:
raise LockError("Cannot specify a negative timeout")
raise LockValueError("Cannot specify a negative timeout")

start_time = time.time()
while timeout == 0 or (time.time() - start_time) < timeout:
Expand Down Expand Up @@ -216,19 +229,21 @@ def locked(self, by_us: bool = False) -> bool:
class TaskLock(RedisLock):
def __init__(
self,
redis_client: Redis,
task: Task,
redis_client: Redis | None = None,
lock_name: str | None = None,
lock_timeout: timedelta | None = timedelta(minutes=5),
retry_delay: float = 0.2,
):
random_value = task.request.root_id or task.request.id
if lock_name is None:
if task.name is None:
raise LockError(
raise LockValueError(
"Task.name must not be None if lock_name is not provided."
)
name = ["Task", task.name]
else:
name = [lock_name]
if redis_client is None:
redis_client = task.services.redis.client()
super().__init__(redis_client, name, random_value, lock_timeout, retry_delay)
9 changes: 2 additions & 7 deletions tests/manager/celery/tasks/test_marc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from palace.manager.celery.tasks.marc import marc_export_collection_lock
from palace.manager.marc.exporter import MarcExporter
from palace.manager.marc.uploader import MarcUploadManager
from palace.manager.service.logging.configuration import LogLevel
from palace.manager.service.redis.models.lock import RedisLock
from palace.manager.service.redis.models.lock import LockNotAcquired, RedisLock
from palace.manager.sqlalchemy.model.collection import Collection
from palace.manager.sqlalchemy.model.marcfile import MarcFile
from palace.manager.sqlalchemy.model.work import Work
Expand Down Expand Up @@ -344,16 +343,12 @@ def test_locked(
redis_fixture: RedisFixture,
marc_exporter_fixture: MarcExporterFixture,
marc_export_collection_fixture: MarcExportCollectionFixture,
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(LogLevel.info)
collection = marc_exporter_fixture.collection1
marc_export_collection_fixture.redis_lock(collection).acquire()
marc_export_collection_fixture.setup_mock_storage()
with patch.object(MarcExporter, "query_works") as query:
with pytest.raises(LockNotAcquired):
marc_export_collection_fixture.export_collection(collection)
query.assert_not_called()
assert "another task is already processing it" in caplog.text


def test_marc_export_cleanup(
Expand Down
8 changes: 3 additions & 5 deletions tests/manager/celery/tasks/test_opds_odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
remove_expired_holds_for_collection_task,
)
from palace.manager.service.logging.configuration import LogLevel
from palace.manager.service.redis.models.lock import LockNotAcquired
from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent
from palace.manager.sqlalchemy.model.collection import Collection
from palace.manager.sqlalchemy.model.licensing import License, LicensePool
Expand Down Expand Up @@ -384,19 +385,16 @@ def test_already_running(
celery_fixture: CeleryFixture,
redis_fixture: RedisFixture,
db: DatabaseTransactionFixture,
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(LogLevel.info)

collection = db.collection(protocol=OPDS2WithODLApi)
assert collection.id is not None
lock = _redis_lock_recalculate_holds(redis_fixture.client, collection.id)

# Acquire the lock, to simulate another task already running
lock.acquire()
recalculate_hold_queue_collection.delay(collection.id).wait()

assert "another task holds its lock" in caplog.text
with pytest.raises(LockNotAcquired):
recalculate_hold_queue_collection.delay(collection.id).wait()

def test_collection_deleted(
self,
Expand Down
14 changes: 6 additions & 8 deletions tests/manager/celery/tasks/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from palace.manager.core.exceptions import BasePalaceException
from palace.manager.scripts.initialization import InstanceInitializationScript
from palace.manager.search.external_search import Filter
from palace.manager.service.redis.models.lock import TaskLock
from palace.manager.service.redis.models.lock import LockNotAcquired, TaskLock
from palace.manager.service.redis.models.search import WaitingForIndexing
from tests.fixtures.celery import CeleryFixture
from tests.fixtures.database import DatabaseTransactionFixture
Expand All @@ -36,7 +36,7 @@ def __init__(self, redis_fixture: RedisFixture):
self.task = MagicMock()
self.task.request.root_id = "fake"
self.task_lock = TaskLock(
self.redis_client, self.task, lock_name="search_reindex"
self.task, lock_name="search_reindex", redis_client=self.redis_client
)


Expand Down Expand Up @@ -112,10 +112,10 @@ def test_search_reindex_lock(
):
search_reindex_task_lock_fixture.task_lock.acquire()

with pytest.raises(BasePalaceException) as exc_info:
with pytest.raises(LockNotAcquired) as exc_info:
search_reindex.delay().wait()

assert "Another re-index task is already running." in str(exc_info.value)
assert "TaskLock::search_reindex could not be acquired" in str(exc_info.value)


def test_fiction_query_returns_results(
Expand Down Expand Up @@ -351,7 +351,7 @@ def __init__(self, redis_fixture: RedisFixture):
task = MagicMock()
task.request.root_id = "fake"
task.name = "palace.manager.celery.tasks.search.search_indexing"
self.lock = TaskLock(self.redis_client, task)
self.lock = TaskLock(task, redis_client=self.redis_client)

self.waiting = WaitingForIndexing(self.redis_client)
self.mock_works = {w_id for w_id in range(10)}
Expand All @@ -371,11 +371,9 @@ def test_search_indexing_lock(
):
search_indexing_fixture.lock.acquire()

with pytest.raises(BasePalaceException) as exc_info:
with pytest.raises(LockNotAcquired):
search_indexing.delay().wait()

assert "search_indexing is already running." in str(exc_info.value)


@pytest.mark.parametrize("batch_size", [3, 5, 500])
@patch("palace.manager.celery.tasks.search.index_works")
Expand Down
34 changes: 26 additions & 8 deletions tests/manager/service/redis/models/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import pytest

from palace.manager.celery.task import Task
from palace.manager.service.redis.models.lock import LockError, RedisLock, TaskLock
from palace.manager.service.redis.models.lock import (
LockNotAcquired,
LockValueError,
RedisLock,
TaskLock,
)
from tests.fixtures.redis import RedisFixture


Expand Down Expand Up @@ -47,7 +52,7 @@ def test_acquire(

def test_acquire_blocking(self, redis_lock_fixture: RedisLockFixture):
# If you specify a negative timeout, you should get an error
with pytest.raises(LockError):
with pytest.raises(LockValueError):
redis_lock_fixture.lock.acquire_blocking(timeout=-5)

# If you acquire the lock with blocking, it will block until the lock is available or times out.
Expand Down Expand Up @@ -114,11 +119,22 @@ def test_lock(self, redis_lock_fixture: RedisLockFixture):
assert redis_lock_fixture.lock.locked() is True
assert redis_lock_fixture.lock.locked() is False

# The context manager returns LockReturn.acquired if the lock is acquired
# The context manager returns false if the lock is not acquired and the
# raise_when_not_acquired parameter is set to False
with redis_lock_fixture.no_timeout_lock.lock():
with redis_lock_fixture.lock.lock() as acquired:
with redis_lock_fixture.lock.lock(
raise_when_not_acquired=False
) as acquired:
assert not acquired

# If the raise_when_not_acquired parameter is set (the default), the context manager raises
# an exception if the lock is not acquired.
with redis_lock_fixture.no_timeout_lock.lock() as acquired:
assert acquired
with pytest.raises(LockNotAcquired):
with redis_lock_fixture.lock.lock():
...

# If the lock is extended, the context manager returns True
redis_lock_fixture.lock.acquire()
with redis_lock_fixture.lock.lock() as acquired:
Expand Down Expand Up @@ -171,14 +187,16 @@ def test___init__(self, redis_fixture: RedisFixture):
mock_task.name = None

# If we don't provide a lock_name, and the task name is None, we should get an error
with pytest.raises(LockError):
TaskLock(redis_fixture.client, mock_task)
with pytest.raises(LockValueError):
TaskLock(mock_task, redis_client=redis_fixture.client)

# If we don't provide a lock_name, we should use the task name
mock_task.name = "test_task"
task_lock = TaskLock(redis_fixture.client, mock_task)
task_lock = TaskLock(mock_task, redis_client=redis_fixture.client)
assert task_lock.key.endswith("::TaskLock::Task::test_task")

# If we provide a lock_name, we should use that instead
task_lock = TaskLock(redis_fixture.client, mock_task, lock_name="test_lock")
task_lock = TaskLock(
mock_task, lock_name="test_lock", redis_client=redis_fixture.client
)
assert task_lock.key.endswith("::TaskLock::test_lock")

0 comments on commit 1a190eb

Please sign in to comment.