Skip to content

Commit

Permalink
Improve typing of default_task_backend to fix mypy issues
Browse files Browse the repository at this point in the history
`mypy` issues were caused by typeddjango/django-stubs#2311, but this improves the types further.
  • Loading branch information
RealOrangeOne committed Aug 2, 2024
1 parent 39ceb97 commit 63fcda1
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 38 deletions.
6 changes: 4 additions & 2 deletions django_tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
django_stubs_ext.monkeypatch()

import importlib.metadata
from typing import Mapping, Optional, cast
from typing import Optional

from django.utils.connection import BaseConnectionHandler, ConnectionProxy
from django.utils.module_loading import import_string
Expand Down Expand Up @@ -67,4 +67,6 @@ def create_connection(self, alias: str) -> BaseTaskBackend:

tasks = TasksHandler()

default_task_backend = ConnectionProxy(cast(Mapping, tasks), DEFAULT_TASK_BACKEND_ALIAS)
default_task_backend: BaseTaskBackend = ConnectionProxy( # type:ignore[assignment]
tasks, DEFAULT_TASK_BACKEND_ALIAS
)
20 changes: 10 additions & 10 deletions tests/tests/test_database_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from django.urls import reverse
from django.utils import timezone

from django_tasks import ResultStatus, default_task_backend, tasks
from django_tasks import ResultStatus, Task, default_task_backend, tasks
from django_tasks.backends.database import DatabaseBackend
from django_tasks.backends.database.management.commands.db_worker import (
logger as db_worker_logger,
Expand All @@ -43,7 +43,7 @@ def test_using_correct_backend(self) -> None:
def test_enqueue_task(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task), self.assertNumQueries(1):
result = default_task_backend.enqueue(task, (1,), {"two": 3})
result = cast(Task, task).enqueue(1, two=3)

self.assertEqual(result.status, ResultStatus.NEW)
self.assertIsNone(result.started_at)
Expand All @@ -58,7 +58,7 @@ def test_enqueue_task(self) -> None:
async def test_enqueue_task_async(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task):
result = await default_task_backend.aenqueue(task, [], {})
result = await cast(Task, task).aenqueue()

self.assertEqual(result.status, ResultStatus.NEW)
self.assertIsNone(result.started_at)
Expand Down Expand Up @@ -127,11 +127,11 @@ async def test_refresh_result_async(self) -> None:

def test_get_missing_result(self) -> None:
with self.assertRaises(ResultDoesNotExist):
default_task_backend.get_result(uuid.uuid4())
default_task_backend.get_result(str(uuid.uuid4()))

async def test_async_get_missing_result(self) -> None:
with self.assertRaises(ResultDoesNotExist):
await default_task_backend.aget_result(uuid.uuid4())
await default_task_backend.aget_result(str(uuid.uuid4()))

def test_invalid_uuid(self) -> None:
with self.assertRaises(ResultDoesNotExist):
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_database_backend_app_missing(self) -> None:
errors = list(default_task_backend.check())

self.assertEqual(len(errors), 1)
self.assertIn("django_tasks.backends.database", errors[0].hint)
self.assertIn("django_tasks.backends.database", errors[0].hint) # type:ignore[arg-type]

def test_priority_range_check(self) -> None:
with self.assertRaises(IntegrityError):
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_run_enqueued_task(self) -> None:
test_tasks.noop_task_async,
]:
with self.subTest(task):
result = default_task_backend.enqueue(task, [], {})
result = cast(Task, task).enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

self.assertEqual(result.status, ResultStatus.NEW)
Expand All @@ -274,8 +274,8 @@ def test_run_enqueued_task(self) -> None:
result.refresh()
self.assertIsNotNone(result.started_at)
self.assertIsNotNone(result.finished_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at)
self.assertGreaterEqual(result.finished_at, result.started_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type]
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type,misc]
self.assertEqual(result.status, ResultStatus.COMPLETE)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand Down Expand Up @@ -777,7 +777,7 @@ def test_get_locked_with_locked_rows(self) -> None:
normalize_uuid(result_2.id),
)
self.assertEqual(
normalize_uuid(DBTaskResult.objects.get_locked().id), # type:ignore[union-attr, arg-type]
normalize_uuid(DBTaskResult.objects.get_locked().id), # type:ignore[union-attr]
normalize_uuid(result_2.id),
)
finally:
Expand Down
19 changes: 11 additions & 8 deletions tests/tests/test_dummy_backend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
from typing import cast

from django.test import SimpleTestCase, override_settings
from django.urls import reverse

from django_tasks import ResultStatus, default_task_backend, tasks
from django_tasks import ResultStatus, Task, default_task_backend, tasks
from django_tasks.backends.dummy import DummyBackend
from django_tasks.exceptions import ResultDoesNotExist
from tests import tasks as test_tasks
Expand All @@ -14,7 +15,7 @@
)
class DummyBackendTestCase(SimpleTestCase):
def setUp(self) -> None:
default_task_backend.clear()
default_task_backend.clear() # type:ignore[attr-defined]

def test_using_correct_backend(self) -> None:
self.assertEqual(default_task_backend, tasks["default"])
Expand All @@ -23,7 +24,7 @@ def test_using_correct_backend(self) -> None:
def test_enqueue_task(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task):
result = default_task_backend.enqueue(task, (1,), {"two": 3})
result = cast(Task, task).enqueue(1, two=3)

self.assertEqual(result.status, ResultStatus.NEW)
self.assertIsNone(result.started_at)
Expand All @@ -34,12 +35,12 @@ def test_enqueue_task(self) -> None:
self.assertEqual(result.args, [1])
self.assertEqual(result.kwargs, {"two": 3})

self.assertIn(result, default_task_backend.results)
self.assertIn(result, default_task_backend.results) # type:ignore[attr-defined]

async def test_enqueue_task_async(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task):
result = await default_task_backend.aenqueue(task, (), {})
result = await cast(Task, task).aenqueue()

self.assertEqual(result.status, ResultStatus.NEW)
self.assertIsNone(result.started_at)
Expand All @@ -50,7 +51,7 @@ async def test_enqueue_task_async(self) -> None:
self.assertEqual(result.args, [])
self.assertEqual(result.kwargs, {})

self.assertIn(result, default_task_backend.results)
self.assertIn(result, default_task_backend.results) # type:ignore[attr-defined]

def test_get_result(self) -> None:
result = default_task_backend.enqueue(test_tasks.noop_task, (), {})
Expand All @@ -71,7 +72,8 @@ def test_refresh_result(self) -> None:
test_tasks.calculate_meaning_of_life, (), {}
)

default_task_backend.results[0].status = ResultStatus.COMPLETE
enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined]
enqueued_result.status = ResultStatus.COMPLETE

self.assertEqual(result.status, ResultStatus.NEW)
result.refresh()
Expand All @@ -82,7 +84,8 @@ async def test_refresh_result_async(self) -> None:
test_tasks.calculate_meaning_of_life, (), {}
)

default_task_backend.results[0].status = ResultStatus.COMPLETE
enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined]
enqueued_result.status = ResultStatus.COMPLETE

self.assertEqual(result.status, ResultStatus.NEW)
await result.arefresh()
Expand Down
32 changes: 17 additions & 15 deletions tests/tests/test_immediate_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
from typing import cast

from django.test import SimpleTestCase, override_settings
from django.urls import reverse
from django.utils import timezone

from django_tasks import ResultStatus, default_task_backend, tasks
from django_tasks import ResultStatus, Task, default_task_backend, tasks
from django_tasks.backends.immediate import ImmediateBackend
from django_tasks.exceptions import InvalidTaskError
from tests import tasks as test_tasks
Expand All @@ -21,13 +22,13 @@ def test_using_correct_backend(self) -> None:
def test_enqueue_task(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task):
result = default_task_backend.enqueue(task, (1,), {"two": 3})
result = cast(Task, task).enqueue(1, two=3)

self.assertEqual(result.status, ResultStatus.COMPLETE)
self.assertIsNotNone(result.started_at)
self.assertIsNotNone(result.finished_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at)
self.assertGreaterEqual(result.finished_at, result.started_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type]
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type, misc]
self.assertIsNone(result.result)
self.assertEqual(result.task, task)
self.assertEqual(result.args, [1])
Expand All @@ -36,13 +37,13 @@ def test_enqueue_task(self) -> None:
async def test_enqueue_task_async(self) -> None:
for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
with self.subTest(task):
result = await default_task_backend.aenqueue(task, (), {})
result = await cast(Task, task).aenqueue()

self.assertEqual(result.status, ResultStatus.COMPLETE)
self.assertIsNotNone(result.started_at)
self.assertIsNotNone(result.finished_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at)
self.assertGreaterEqual(result.finished_at, result.started_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type]
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type, misc]
self.assertIsNone(result.result)
self.assertIsNone(result.get_result())
self.assertEqual(result.task, task)
Expand All @@ -66,7 +67,7 @@ def test_catches_exception(self) -> None:
with self.subTest(task), self.assertLogs(
"django_tasks.backends.immediate", level="ERROR"
) as captured_logs:
result = default_task_backend.enqueue(task, [], {})
result = task.enqueue()

# assert logging
self.assertEqual(len(captured_logs.output), 1)
Expand All @@ -76,11 +77,12 @@ def test_catches_exception(self) -> None:
self.assertEqual(result.status, ResultStatus.FAILED)
self.assertIsNotNone(result.started_at)
self.assertIsNotNone(result.finished_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at)
self.assertGreaterEqual(result.finished_at, result.started_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type]
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type, misc]
self.assertIsInstance(result.result, exception)
self.assertTrue(
result.traceback.endswith(f"{exception.__name__}: {message}\n")
result.traceback
and result.traceback.endswith(f"{exception.__name__}: {message}\n")
)
self.assertIsNone(result.get_result())
self.assertEqual(result.task, task)
Expand All @@ -104,13 +106,13 @@ def test_throws_keyboard_interrupt(self) -> None:

def test_complex_exception(self) -> None:
with self.assertLogs("django_tasks.backends.immediate", level="ERROR"):
result = default_task_backend.enqueue(test_tasks.complex_exception, [], {})
result = test_tasks.complex_exception.enqueue()

self.assertEqual(result.status, ResultStatus.FAILED)
self.assertIsNotNone(result.started_at)
self.assertIsNotNone(result.finished_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at)
self.assertGreaterEqual(result.finished_at, result.started_at)
self.assertGreaterEqual(result.started_at, result.enqueued_at) # type:ignore[arg-type]
self.assertGreaterEqual(result.finished_at, result.started_at) # type:ignore[arg-type,misc]

self.assertIsNone(result.result)
self.assertIsNone(result.traceback)
Expand Down Expand Up @@ -147,7 +149,7 @@ async def test_cannot_get_result(self) -> None:
NotImplementedError,
"This backend does not support retrieving or refreshing results.",
):
await default_task_backend.get_result(123)
await default_task_backend.aget_result(123) # type:ignore[arg-type]

async def test_cannot_refresh_result(self) -> None:
result = default_task_backend.enqueue(
Expand Down
6 changes: 3 additions & 3 deletions tests/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
class TaskTestCase(SimpleTestCase):
def setUp(self) -> None:
default_task_backend.clear()
default_task_backend.clear() # type:ignore[attr-defined]

def test_using_correct_backend(self) -> None:
self.assertEqual(default_task_backend, tasks["default"])
Expand All @@ -55,7 +55,7 @@ def test_enqueue_task(self) -> None:
self.assertEqual(result.args, [])
self.assertEqual(result.kwargs, {})

self.assertEqual(default_task_backend.results, [result])
self.assertEqual(default_task_backend.results, [result]) # type:ignore[attr-defined]

async def test_enqueue_task_async(self) -> None:
result = await test_tasks.noop_task.aenqueue()
Expand All @@ -65,7 +65,7 @@ async def test_enqueue_task_async(self) -> None:
self.assertEqual(result.args, [])
self.assertEqual(result.kwargs, {})

self.assertEqual(default_task_backend.results, [result])
self.assertEqual(default_task_backend.results, [result]) # type:ignore[attr-defined]

def test_using_priority(self) -> None:
self.assertEqual(test_tasks.noop_task.priority, 0)
Expand Down

0 comments on commit 63fcda1

Please sign in to comment.