Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove some boilerplate in tests #4156

Merged
merged 2 commits into from
Nov 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/4156.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
HTTP tests have been refactored to contain less boilerplate.
116 changes: 53 additions & 63 deletions tests/rest/client/v1/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,17 @@

from mock import Mock

from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock

from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)


class UserRegisterTestCase(unittest.TestCase):
def setUp(self):
class UserRegisterTestCase(unittest.HomeserverTestCase):

servlets = [register_servlets]

def make_homeserver(self, reactor, clock):

self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"

self.registration_handler = Mock()
Expand All @@ -50,17 +43,14 @@ def setUp(self):

self.secrets = Mock()

self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs = self.setup_test_homeserver()

self.hs.config.registration_shared_secret = u"shared"

self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()

self.resource = JsonResource(self.hs)
register_servlets(self.hs, self.resource)
return self.hs

def test_disabled(self):
"""
Expand All @@ -69,8 +59,8 @@ def test_disabled(self):
"""
self.hs.config.registration_shared_secret = None

request, channel = make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, b'{}')
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
Expand All @@ -87,8 +77,8 @@ def test_get_nonce(self):

self.hs.get_secrets = Mock(return_value=secrets)

request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)

self.assertEqual(channel.json_body, {"nonce": "abcd"})

Expand All @@ -97,25 +87,25 @@ def test_expired_nonce(self):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

# 59 seconds
self.clock.advance(59)
self.reactor.advance(59)

body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])

# 61 seconds
self.clock.advance(2)
self.reactor.advance(2)

request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
Expand All @@ -124,8 +114,8 @@ def test_register_incorrect_nonce(self):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -141,8 +131,8 @@ def test_register_incorrect_nonce(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
Expand All @@ -152,8 +142,8 @@ def test_register_correct_nonce(self):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -169,8 +159,8 @@ def test_register_correct_nonce(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
Expand All @@ -179,8 +169,8 @@ def test_nonce_reuse(self):
"""
A valid unrecognised nonce.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -196,15 +186,15 @@ def test_nonce_reuse(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])

# Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
Expand All @@ -217,8 +207,8 @@ def test_missing_parts(self):
"""

def nonce():
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
return channel.json_body["nonce"]

#
Expand All @@ -227,8 +217,8 @@ def nonce():

# Must be present
body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
Expand All @@ -239,32 +229,32 @@ def nonce():

# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])

# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])

# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])

# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
Expand All @@ -275,16 +265,16 @@ def nonce():

# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])

# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
Expand All @@ -293,16 +283,16 @@ def nonce():
body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])

# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
10 changes: 5 additions & 5 deletions tests/rest/client/v1/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def setUp(self):
)

handlers = Mock(registration_handler=self.registration_handler)
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.reactor = MemoryReactorClock()
self.hs_clock = Clock(self.reactor)

self.hs = self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
Expand All @@ -76,8 +76,8 @@ def test_POST_createuser_with_valid_user(self):
return_value=(user_id, token)
)

request, channel = make_request(b"POST", url, request_data)
render(request, res, self.clock)
request, channel = make_request(self.reactor, b"POST", url, request_data)
render(request, res, self.reactor)

self.assertEquals(channel.result["code"], b"200")

Expand Down
22 changes: 7 additions & 15 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def create_room_as(self, room_creator, is_public=True, tok=None):
path = path + "?access_token=%s" % tok

request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())

Expand Down Expand Up @@ -217,7 +217,9 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

data = {"membership": membership}

request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
)

render(request, self.resource, self.hs.get_reactor())

Expand All @@ -228,18 +230,6 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

self.auth_user_id = temp_id

@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/_matrix/client/r0/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code)
defer.returnValue(response)

def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
Expand All @@ -251,7 +241,9 @@ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if tok:
path = path + "?access_token=%s" % tok

request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())

assert int(channel.result["code"]) == expect_code, (
Expand Down
Loading