diff --git a/changelog.d/13298.misc b/changelog.d/13298.misc new file mode 100644 index 000000000000..545a62369f43 --- /dev/null +++ b/changelog.d/13298.misc @@ -0,0 +1 @@ +Use `HTTPStatus` constants in place of literals in tests. \ No newline at end of file diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a2958f6959e7..0553073fa5e0 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -134,10 +134,12 @@ def test_POST_ratelimiting_per_address(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -152,7 +154,7 @@ def test_POST_ratelimiting_per_address(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config( { @@ -179,10 +181,12 @@ def test_POST_ratelimiting_per_account(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -197,7 +201,7 @@ def test_POST_ratelimiting_per_account(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config( { @@ -224,10 +228,12 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -242,7 +248,7 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -250,7 +256,7 @@ def test_soft_logout(self) -> None: # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal @@ -354,7 +360,7 @@ def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -380,7 +386,7 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_with_overly_long_device_id_fails(self) -> None: self.register_user("mickey", "cheese") @@ -878,17 +884,17 @@ def jwt_login(self, *args: Any) -> FakeChannel: def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -897,7 +903,7 @@ def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -907,7 +913,7 @@ def test_login_jwt_expired(self) -> None: def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -916,7 +922,7 @@ def test_login_jwt_not_before(self) -> None: def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @@ -925,12 +931,12 @@ def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -939,7 +945,7 @@ def test_login_iss(self) -> None: # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -949,7 +955,7 @@ def test_login_iss(self) -> None: def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @@ -957,12 +963,12 @@ def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -971,7 +977,7 @@ def test_login_aud(self) -> None: # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -981,7 +987,7 @@ def test_login_aud(self) -> None: def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -991,20 +997,20 @@ def test_login_aud_no_config(self) -> None: def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -1086,12 +1092,12 @@ def jwt_login(self, *args: Any) -> FakeChannel: def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -1152,7 +1158,7 @@ def test_login_appservice_user(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1166,7 +1172,7 @@ def test_login_appservice_user_bot(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1180,7 +1186,7 @@ def test_login_appservice_wrong_user(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1194,7 +1200,7 @@ def test_login_appservice_wrong_as(self) -> None: b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1208,7 +1214,7 @@ def test_login_appservice_no_token(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 7401b5e0c0fa..909c017e8840 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -67,7 +68,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = HTTPStatus.OK, ) -> JsonDict: """Helper function to send a redaction event. @@ -76,12 +81,12 @@ def _redact_event( path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) channel = self.make_request("POST", path, content={}, access_token=access_token) - self.assertEqual(int(channel.result["code"]), expect_code) + self.assertEqual(channel.code, expect_code) return channel.json_body def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, HTTPStatus.OK) room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] @@ -117,7 +122,10 @@ def test_redact_event_as_normal(self) -> None: # as a normal, try to redact the admin's event self._redact_event( - self.other_access_token, self.room_id, admin_msg_id, expect_code=403 + self.other_access_token, + self.room_id, + admin_msg_id, + expect_code=HTTPStatus.FORBIDDEN, ) # now try to redact our own event @@ -153,7 +161,10 @@ def test_redact_nonexistent_event(self) -> None: # ... but normals cannot self._redact_event( - self.other_access_token, self.room_id, "$zzz", expect_code=404 + self.other_access_token, + self.room_id, + "$zzz", + expect_code=HTTPStatus.NOT_FOUND, ) # when we sync, we should see only the valid redaction @@ -178,12 +189,18 @@ def test_redact_create_event(self) -> None: # room moderators cannot send redactions for create events self._redact_event( - self.mod_access_token, self.room_id, create_event_id, expect_code=403 + self.mod_access_token, + self.room_id, + create_event_id, + expect_code=HTTPStatus.FORBIDDEN, ) # and nor can normals self._redact_event( - self.other_access_token, self.room_id, create_event_id, expect_code=403 + self.other_access_token, + self.room_id, + create_event_id, + expect_code=HTTPStatus.FORBIDDEN, ) def test_redact_event_as_moderator_ratelimit(self) -> None: diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index f8e64ce6ac9c..2326fbf0294b 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -15,6 +15,7 @@ # limitations under the License. import datetime import os +from http import HTTPStatus from typing import Any, Dict, List, Tuple import pkg_resources @@ -70,7 +71,7 @@ def test_POST_appservice_registration_valid(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -91,7 +92,7 @@ def test_POST_appservice_registration_no_type(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists @@ -100,20 +101,20 @@ def test_POST_appservice_registration_invalid(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) def test_POST_bad_password(self) -> None: request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: @@ -132,7 +133,7 @@ def test_POST_user_valid(self) -> None: "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) @@ -142,7 +143,7 @@ def test_POST_disabled_registration(self) -> None: channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -153,7 +154,7 @@ def test_POST_guest_registration(self) -> None: channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: @@ -161,7 +162,7 @@ def test_POST_disabled_guest_registration(self) -> None: channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @@ -171,16 +172,18 @@ def test_POST_ratelimiting_guest(self) -> None: channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: @@ -194,16 +197,18 @@ def test_POST_ratelimiting(self) -> None: channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self) -> None: @@ -231,7 +236,7 @@ def test_POST_registration_requires_token(self) -> None: # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -248,7 +253,7 @@ def test_POST_registration_requires_token(self) -> None: "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -263,7 +268,7 @@ def test_POST_registration_requires_token(self) -> None: "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -293,21 +298,21 @@ def test_POST_registration_token_invalid(self) -> None: "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -361,7 +366,7 @@ def test_POST_registration_token_limit_uses(self) -> None: "session": session2, } channel = self.make_request(b"POST", self.url, params2) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -381,7 +386,7 @@ def test_POST_registration_token_limit_uses(self) -> None: # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, params2) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -415,7 +420,7 @@ def test_POST_registration_token_expiry(self) -> None: "session": session, } channel = self.make_request(b"POST", self.url, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -570,7 +575,7 @@ def test_POST_registration_token_session_expiry_deleted_token(self) -> None: def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -593,7 +598,7 @@ def test_advertised_flows(self) -> None: ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -625,7 +630,7 @@ def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: ) def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -669,7 +674,7 @@ def test_request_token_existing_email_inhibit_error(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -692,7 +697,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. self.assertEqual( channel.json_body, @@ -705,7 +710,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, @@ -718,7 +723,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, @@ -743,7 +748,7 @@ def test_inhibit_user_in_use_error(self) -> None: # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -752,7 +757,7 @@ def test_inhibit_user_in_use_error(self) -> None: "register", {"username": username, "type": "m.login.password", "password": "foo"}, ) - self.assertEqual(channel.code, 401) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED) self.assertIn("session", channel.json_body) # Test that finishing the registration fails because of a conflicting username. @@ -762,7 +767,7 @@ def test_inhibit_user_in_use_error(self) -> None: "register", {"auth": {"session": session, "type": LoginType.DUMMY}}, ) - self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) @@ -797,13 +802,13 @@ def test_validity_period(self) -> None: # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -823,12 +828,12 @@ def test_manual_renewal(self) -> None: url = "/_synapse/admin/v1/account_validity/validity" request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") @@ -844,12 +849,12 @@ def test_manual_expire(self) -> None: "enable_renewal_emails": False, } channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -868,18 +873,18 @@ def test_logging_out_expired_user(self) -> None: "enable_renewal_emails": False, } channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -954,7 +959,7 @@ def test_renewal_email(self) -> None: renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -972,7 +977,7 @@ def test_renewal_email(self) -> None: # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -992,14 +997,14 @@ def test_renewal_email(self) -> None: # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1023,7 +1028,7 @@ def test_manual_email_send(self) -> None: "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1043,7 +1048,7 @@ def test_deactivated_user(self) -> None: channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -1096,7 +1101,7 @@ def test_manual_email_send_expired_account(self) -> None: "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1176,7 +1181,7 @@ def test_GET_token_valid(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self) -> None: @@ -1185,7 +1190,7 @@ def test_GET_token_invalid(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["valid"], False) @override_config( @@ -1201,10 +1206,12 @@ def test_GET_ratelimiting(self) -> None: ) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1212,4 +1219,4 @@ def test_GET_ratelimiting(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)