diff --git a/tests/test_types.py b/tests/test_types.py index d8d82a517ea9..8b8b5505c014 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.types import RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import MXCUri, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest @@ -96,3 +96,81 @@ def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") + + +class MXCUriTestCase(unittest.TestCase): + def test_valid_mxc_uris(self): + """Tests that a series of valid mxc uris are parsed correctly.""" + # Converting an MXCUri to its str representation + mxc_1 = MXCUri(server_name="example.com", media_id="84n8493hnfsjkbcu") + self.assertEqual(mxc_1.to_string(), "mxc://example.com/84n8493hnfsjkbcu") + + mxc_2 = MXCUri( + server_name="192.168.1.17:8008", media_id="bajkad89h31ausdhoqqasd" + ) + self.assertEqual( + mxc_2.to_string(), "mxc://192.168.1.17:8008/bajkad89h31ausdhoqqasd" + ) + + mxc_3 = MXCUri(server_name="123.123.123.123", media_id="000000000000") + self.assertEqual(mxc_3.to_string(), "mxc://123.123.123.123/000000000000") + + # Converting a str to its MXCUri representation + mxcuri_1 = MXCUri.from_str("mxc://example.com/g12789g890ajksjk") + self.assertEqual(mxcuri_1.server_name, "example.com") + self.assertEqual(mxcuri_1.media_id, "g12789g890ajksjk") + + mxcuri_2 = MXCUri.from_str("mxc://localhost:8448/abcdefghijklmnopqrstuvwxyz") + self.assertEqual(mxcuri_2.server_name, "localhost:8448") + self.assertEqual(mxcuri_2.media_id, "abcdefghijklmnopqrstuvwxyz") + + mxcuri_3 = MXCUri.from_str("mxc://[::1]/abcdefghijklmnopqrstuvwxyz") + self.assertEqual(mxcuri_3.server_name, "[::1]") + self.assertEqual(mxcuri_3.media_id, "abcdefghijklmnopqrstuvwxyz") + + mxcuri_4 = MXCUri.from_str("mxc://123.123.123.123:32112/12893y81283781023") + self.assertEqual(mxcuri_4.server_name, "123.123.123.123:32112") + self.assertEqual(mxcuri_4.media_id, "12893y81283781023") + + mxcuri_5 = MXCUri.from_str("mxc://domain/abcdefg") + self.assertEqual(mxcuri_5.server_name, "domain") + self.assertEqual(mxcuri_5.media_id, "abcdefg") + + def test_invalid_mxc_uris(self): + """Tests that a series of invalid mxc uris are appropriately rejected.""" + # Converting a str to its MXCUri representation + with self.assertRaises(ValueError): + MXCUri.from_str("http://example.com/abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc:///example.com/abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com//abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com/abcdef/") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com/abc/abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com/abc/abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc:///abcdef") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc://example.com/") + + with self.assertRaises(ValueError): + MXCUri.from_str("mxc:///") + + with self.assertRaises(ValueError): + MXCUri.from_str("") + + with self.assertRaises(ValueError): + MXCUri.from_str(None) # type: ignore