diff --git a/oauth2_provider/management/commands/cleartokens.py b/oauth2_provider/management/commands/cleartokens.py index 3fb1827f6..9d58361bc 100644 --- a/oauth2_provider/management/commands/cleartokens.py +++ b/oauth2_provider/management/commands/cleartokens.py @@ -3,7 +3,7 @@ from ...models import clear_expired -class Command(BaseCommand): +class Command(BaseCommand): # pragma: no cover help = "Can be run as a cronjob or directly to clean out expired tokens" def handle(self, *args, **options): diff --git a/tests/test_models.py b/tests/test_models.py index 7b37486ca..9ce1e5eb7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,5 @@ +from datetime import timedelta + import pytest from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured, ValidationError @@ -294,7 +296,11 @@ def test_str(self): class TestClearExpired(BaseTestModels): def setUp(self): super().setUp() - # Insert two tokens on database. + # Insert many tokens, both expired and not, and grants. + self.num_tokens = 100 + now = timezone.now() + earlier = now - timedelta(seconds=100) + later = now + timedelta(seconds=100) app = Application.objects.create( name="test_app", redirect_uris="http://localhost http://example.com http://example.org", @@ -302,23 +308,54 @@ def setUp(self): client_type=Application.CLIENT_CONFIDENTIAL, authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - AccessToken.objects.create( - token="555", - expires=timezone.now(), - scope=2, - application=app, - user=self.user, - created=timezone.now(), - updated=timezone.now(), + # make 200 access tokens, half current and half expired. + expired_access_tokens = AccessToken.objects.bulk_create( + AccessToken(token="expired AccessToken {}".format(i), expires=earlier) + for i in range(self.num_tokens) ) - AccessToken.objects.create( - token="666", - expires=timezone.now(), - scope=2, - application=app, - user=self.user, - created=timezone.now(), - updated=timezone.now(), + current_access_tokens = AccessToken.objects.bulk_create( + AccessToken(token=f"current AccessToken {i}", expires=later) for i in range(self.num_tokens) + ) + # Give the first half of the access tokens a refresh token, + # alternating between current and expired ones. + RefreshToken.objects.bulk_create( + RefreshToken( + token=f"expired AT's refresh token {i}", + application=app, + access_token=expired_access_tokens[i].pk, + user=self.user, + ) + for i in range(0, len(expired_access_tokens) // 2, 2) + ) + RefreshToken.objects.bulk_create( + RefreshToken( + token=f"current AT's refresh token {i}", + application=app, + access_token=current_access_tokens[i].pk, + user=self.user, + ) + for i in range(1, len(current_access_tokens) // 2, 2) + ) + # Make some grants, half of which are expired. + Grant.objects.bulk_create( + Grant( + user=self.user, + code=f"old grant code {i}", + application=app, + expires=earlier, + redirect_uri="https://localhost/redirect", + ) + for i in range(self.num_tokens) + ) + Grant.objects.bulk_create( + Grant( + user=self.user, + code=f"new grant code {i}", + application=app, + expires=later, + redirect_uri="https://localhost/redirect", + ) + for i in range(self.num_tokens) ) def test_clear_expired_tokens(self): @@ -333,15 +370,21 @@ def test_clear_expired_tokens_incorect_timetype(self): assert result == "ImproperlyConfigured" def test_clear_expired_tokens_with_tokens(self): - self.client.login(username="test_user", password="123456") - self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 - ttokens = AccessToken.objects.count() - expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() - assert ttokens == 2 - assert expiredt == 2 + self.oauth2_settings.CLEAR_EXPIRED_TOKENS_BATCH_SIZE = 10 + self.oauth2_settings.CLEAR_EXPIRED_TOKENS_BATCH_INTERVAL = 0.0 + at_count = AccessToken.objects.count() + assert at_count == 2 * self.num_tokens, f"{2 * self.num_tokens} access tokens should exist." + rt_count = RefreshToken.objects.count() + assert rt_count == self.num_tokens // 2, f"{self.num_tokens // 2} refresh tokens should exist." + gt_count = Grant.objects.count() + assert gt_count == self.num_tokens * 2, f"{self.num_tokens * 2} grants should exist." clear_expired() - expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() - assert expiredt == 0 + at_count = AccessToken.objects.count() + assert at_count == self.num_tokens, "Half the access tokens should not have been deleted." + rt_count = RefreshToken.objects.count() + assert rt_count == self.num_tokens // 2, "Half of the refresh tokens should have been deleted." + gt_count = Grant.objects.count() + assert gt_count == self.num_tokens, "Half the grants should have been deleted." @pytest.mark.django_db