Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use HTTPStatus constants in place of literals in tests #13298

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13298.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `HTTPStatus` constants in place of literals in tests.
79 changes: 43 additions & 36 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
Expand Down Expand Up @@ -134,10 +135,12 @@ def test_POST_ratelimiting_per_address(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -152,7 +155,7 @@ def test_POST_ratelimiting_per_address(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config(
{
Expand All @@ -179,10 +182,12 @@ def test_POST_ratelimiting_per_account(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -197,7 +202,7 @@ def test_POST_ratelimiting_per_account(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config(
{
Expand All @@ -224,10 +229,12 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -242,15 +249,15 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
self.register_user("kermit", "monkey")

# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")

# log in as normal
Expand Down Expand Up @@ -354,7 +361,7 @@ def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:

# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
Expand All @@ -380,7 +387,7 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(

# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
Expand Down Expand Up @@ -878,17 +885,17 @@ def jwt_login(self, *args: Any) -> FakeChannel:
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -897,7 +904,7 @@ def test_login_jwt_invalid_signature(self) -> None:

def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -907,7 +914,7 @@ def test_login_jwt_expired(self) -> None:
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -916,7 +923,7 @@ def test_login_jwt_not_before(self) -> None:

def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

Expand All @@ -925,12 +932,12 @@ def test_login_iss(self) -> None:
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -939,7 +946,7 @@ def test_login_iss(self) -> None:

# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -949,20 +956,20 @@ def test_login_iss(self) -> None:
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self) -> None:
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -971,7 +978,7 @@ def test_login_aud(self) -> None:

# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -981,7 +988,7 @@ def test_login_aud(self) -> None:
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -991,20 +998,20 @@ def test_login_aud_no_config(self) -> None:
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")

Expand Down Expand Up @@ -1086,12 +1093,12 @@ def jwt_login(self, *args: Any) -> FakeChannel:

def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand Down Expand Up @@ -1152,7 +1159,7 @@ def test_login_appservice_user(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
Expand All @@ -1166,7 +1173,7 @@ def test_login_appservice_user_bot(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
Expand All @@ -1180,7 +1187,7 @@ def test_login_appservice_wrong_user(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
Expand All @@ -1194,7 +1201,7 @@ def test_login_appservice_wrong_as(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
Expand All @@ -1208,7 +1215,7 @@ def test_login_appservice_no_token(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result)


@skip_unless(HAS_OIDC, "requires OIDC")
Expand Down
31 changes: 24 additions & 7 deletions tests/rest/client/test_redactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import List

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -67,7 +68,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
)

def _redact_event(
self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
self,
access_token: str,
room_id: str,
event_id: str,
expect_code: int = HTTPStatus.OK,
) -> JsonDict:
"""Helper function to send a redaction event.

Expand All @@ -76,12 +81,12 @@ def _redact_event(
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)

channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code)
self.assertEqual(channel.code, expect_code)
return channel.json_body

def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.code, HTTPStatus.OK)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]

Expand Down Expand Up @@ -117,7 +122,10 @@ def test_redact_event_as_normal(self) -> None:

# as a normal, try to redact the admin's event
self._redact_event(
self.other_access_token, self.room_id, admin_msg_id, expect_code=403
self.other_access_token,
self.room_id,
admin_msg_id,
expect_code=HTTPStatus.FORBIDDEN,
)

# now try to redact our own event
Expand Down Expand Up @@ -153,7 +161,10 @@ def test_redact_nonexistent_event(self) -> None:

# ... but normals cannot
self._redact_event(
self.other_access_token, self.room_id, "$zzz", expect_code=404
self.other_access_token,
self.room_id,
"$zzz",
expect_code=HTTPStatus.NOT_FOUND,
)

# when we sync, we should see only the valid redaction
Expand All @@ -178,12 +189,18 @@ def test_redact_create_event(self) -> None:

# room moderators cannot send redactions for create events
self._redact_event(
self.mod_access_token, self.room_id, create_event_id, expect_code=403
self.mod_access_token,
self.room_id,
create_event_id,
expect_code=HTTPStatus.FORBIDDEN,
)

# and nor can normals
self._redact_event(
self.other_access_token, self.room_id, create_event_id, expect_code=403
self.other_access_token,
self.room_id,
create_event_id,
expect_code=HTTPStatus.FORBIDDEN,
)

def test_redact_event_as_moderator_ratelimit(self) -> None:
Expand Down
Loading