Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use MXCUri to simplify MediaRetentionTestCase.
Browse files Browse the repository at this point in the history
Helps clean this test case up quite a bit.
  • Loading branch information
anoadragon453 committed Jul 1, 2022
1 parent 76b14ca commit 23bdd30
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 68 deletions.
6 changes: 3 additions & 3 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.types import MXCUri, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
Expand Down Expand Up @@ -187,7 +187,7 @@ async def create_content(
content: IO,
content_length: int,
auth_user: UserID,
) -> str:
) -> MXCUri:
"""Store uploaded content for a local user and return the mxc URL
Args:
Expand Down Expand Up @@ -220,7 +220,7 @@ async def create_content(

await self._generate_thumbnails(None, media_id, media_id, media_type)

return "mxc://%s/%s" % (self.server_name, media_id)
return MXCUri(self.server_name, media_id)

async def get_local_media(
self, request: SynapseRequest, media_id: str, name: Optional[str]
Expand Down
102 changes: 37 additions & 65 deletions tests/rest/media/test_media_retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.

import io
from typing import Iterable, Optional, Tuple
from typing import Iterable, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.types import MXCUri, UserID
from synapse.util import Clock

from tests import unittest
Expand Down Expand Up @@ -63,9 +63,9 @@ def _create_media_and_set_attributes(
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
is_protected: Optional[bool] = False,
) -> str:
) -> MXCUri:
# "Upload" some media to the local media store
mxc_uri = self.get_success(
mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
Expand All @@ -75,13 +75,11 @@ def _create_media_and_set_attributes(
)
)

media_id = mxc_uri.split("/")[-1]

# Set the last recently accessed time for this media
if last_accessed_ms is not None:
self.get_success(
self.store.update_cached_last_access_time(
local_media=(media_id,),
local_media=(mxc_uri.media_id,),
remote_media=(),
time_ms=last_accessed_ms,
)
Expand All @@ -92,7 +90,7 @@ def _create_media_and_set_attributes(
self.get_success(
self.store.quarantine_media_by_id(
server_name=self.hs.config.server.server_name,
media_id=media_id,
media_id=mxc_uri.media_id,
quarantined_by="@theadmin:test",
)
)
Expand All @@ -101,18 +99,18 @@ def _create_media_and_set_attributes(
# Mark this media as protected from quarantine
self.get_success(
self.store.mark_local_media_as_safe(
media_id=media_id,
media_id=mxc_uri.media_id,
safe=True,
)
)

return media_id
return mxc_uri

def _cache_remote_media_and_set_attributes(
media_id: str,
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
) -> str:
) -> MXCUri:
# Pretend to cache some remote media
self.get_success(
self.store.store_cached_remote_media(
Expand Down Expand Up @@ -146,7 +144,7 @@ def _cache_remote_media_and_set_attributes(
)
)

return media_id
return MXCUri(self.remote_server_name, media_id)

# Start with the local media store
self.local_recently_accessed_media = _create_media_and_set_attributes(
Expand Down Expand Up @@ -214,28 +212,16 @@ def test_local_media_retention(self) -> None:
# Remote media should be unaffected.
self._assert_if_mxc_uris_purged(
purged=[
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_media,
),
(self.hs.config.server.server_name, self.local_never_accessed_media),
self.local_not_recently_accessed_media,
self.local_never_accessed_media,
],
not_purged=[
(self.hs.config.server.server_name, self.local_recently_accessed_media),
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_quarantined_media,
),
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_protected_media,
),
(self.remote_server_name, self.remote_recently_accessed_media),
(self.remote_server_name, self.remote_not_recently_accessed_media),
(
self.remote_server_name,
self.remote_not_recently_accessed_quarantined_media,
),
self.local_recently_accessed_media,
self.local_not_recently_accessed_quarantined_media,
self.local_not_recently_accessed_protected_media,
self.remote_recently_accessed_media,
self.remote_not_recently_accessed_media,
self.remote_not_recently_accessed_quarantined_media,
],
)

Expand All @@ -261,49 +247,35 @@ def test_remote_media_cache_retention(self) -> None:
# Remote media accessed <30 days ago should still exist.
self._assert_if_mxc_uris_purged(
purged=[
(self.remote_server_name, self.remote_not_recently_accessed_media),
self.remote_not_recently_accessed_media,
],
not_purged=[
(self.remote_server_name, self.remote_recently_accessed_media),
(self.hs.config.server.server_name, self.local_recently_accessed_media),
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_media,
),
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_quarantined_media,
),
(
self.hs.config.server.server_name,
self.local_not_recently_accessed_protected_media,
),
(
self.remote_server_name,
self.remote_not_recently_accessed_quarantined_media,
),
(self.hs.config.server.server_name, self.local_never_accessed_media),
self.remote_recently_accessed_media,
self.local_recently_accessed_media,
self.local_not_recently_accessed_media,
self.local_not_recently_accessed_quarantined_media,
self.local_not_recently_accessed_protected_media,
self.remote_not_recently_accessed_quarantined_media,
self.local_never_accessed_media,
],
)

def _assert_if_mxc_uris_purged(
self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri]
) -> None:
def _assert_mxc_uri_purge_state(
server_name: str, media_id: str, expect_purged: bool
) -> None:
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
if server_name == self.hs.config.server.server_name:
if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success(
self.store.get_local_media(media_id)
self.store.get_local_media(mxc_uri.media_id)
)
else:
found_media_dict = self.get_success(
self.store.get_cached_remote_media(server_name, media_id)
self.store.get_cached_remote_media(
mxc_uri.server_name, mxc_uri.media_id
)
)

mxc_uri = f"mxc://{server_name}/{media_id}"

if expect_purged:
self.assertIsNone(
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
Expand All @@ -315,7 +287,7 @@ def _assert_mxc_uri_purge_state(
)

# Assert that the given MXC URIs have either been correctly purged or not.
for server_name, media_id in purged:
_assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
for server_name, media_id in not_purged:
_assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
for mxc_uri in purged:
_assert_mxc_uri_purge_state(mxc_uri, expect_purged=True)
for mxc_uri in not_purged:
_assert_mxc_uri_purge_state(mxc_uri, expect_purged=False)

0 comments on commit 23bdd30

Please sign in to comment.