Skip to content

Commit 15e2521

Browse files
committed
update tests (core)
1 parent 8938731 commit 15e2521

File tree

2 files changed

+47
-69
lines changed

2 files changed

+47
-69
lines changed

sdk/core/azure-core/tests/async_tests/test_challenge_authentication_async.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
#
2525
# --------------------------------------------------------------------------
2626
import asyncio
27-
import base64
28-
import itertools
2927
import time
3028
from unittest.mock import Mock
3129

@@ -40,16 +38,6 @@
4038
pytestmark = pytest.mark.asyncio
4139

4240

43-
class MockPolicy(AsyncChallengeAuthenticationPolicy):
44-
def __init__(self, *args, **kwargs):
45-
super(MockPolicy, self).__init__(*args, **kwargs)
46-
self.on_challenge_called = False
47-
48-
async def on_challenge(self, request, response, challenge):
49-
self.on_challenge_called = True
50-
return False
51-
52-
5341
async def test_adds_header():
5442
"""The bearer token policy should add a header containing a token from its credential"""
5543
# 2524608000 == 01/01/2050 @ 12:00am (UTC)
@@ -67,7 +55,7 @@ async def get_token(_):
6755
return expected_token
6856

6957
fake_credential = Mock(get_token=get_token)
70-
policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
58+
policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
7159
pipeline = AsyncPipeline(transport=Mock(), policies=policies)
7260

7361
await pipeline.run(HttpRequest("GET", "https://localhost"), context=None)
@@ -89,7 +77,7 @@ async def get_token(_):
8977

9078
expected_scope = "scope"
9179
credential = Mock(get_token=Mock(wraps=get_token))
92-
policy = MockPolicy(credential, expected_scope)
80+
policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope)
9381
pipeline = AsyncPipeline(transport=Mock(send=send), policies=[policy])
9482

9583
await pipeline.run(HttpRequest("GET", "https://localhost"))
@@ -107,7 +95,7 @@ async def verify_request(request):
10795
return expected_response
10896

10997
fake_credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("", 0)))
110-
policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)]
98+
policies = [AsyncChallengeAuthenticationPolicy(fake_credential, "scope"), Mock(send=verify_request)]
11199
response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request)
112100

113101
assert response is expected_response
@@ -130,7 +118,7 @@ async def send(_):
130118

131119
transport = Mock(send=send)
132120

133-
pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")])
121+
pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")])
134122
await pipeline.run(HttpRequest("GET", "https://localhost"))
135123
assert get_token_calls == 1 # policy has no token at first request -> it should call get_token
136124
await pipeline.run(HttpRequest("GET", "https://localhost"))
@@ -139,7 +127,7 @@ async def send(_):
139127
expired_token = AccessToken("token", time.time())
140128
get_token_calls = 0
141129
expected_token = expired_token
142-
pipeline = AsyncPipeline(transport=transport, policies=[MockPolicy(credential, "scope")])
130+
pipeline = AsyncPipeline(transport=transport, policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")])
143131

144132
await pipeline.run(HttpRequest("GET", "https://localhost"))
145133
assert get_token_calls == 1
@@ -151,25 +139,25 @@ async def test_optionally_enforces_https():
151139
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""
152140

153141
async def assert_option_popped(request, **kwargs):
154-
assert "enforce_https" not in kwargs, "MockPolicy didn't pop the 'enforce_https' option"
142+
assert "enforce_https" not in kwargs, "AsyncChallengeAuthenticationPolicy didn't pop the 'enforce_https' option"
155143
return Mock()
156144

157145
credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
158-
pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[MockPolicy(credential, "scope")])
146+
pipeline = AsyncPipeline(transport=Mock(send=assert_option_popped), policies=[AsyncChallengeAuthenticationPolicy(credential, "scope")])
159147

160148
# by default and when enforce_https=True, the policy should raise when given an insecure request
161149
with pytest.raises(ServiceRequestError):
162-
await pipeline.run(HttpRequest("GET", "http://not.secure"))
150+
await pipeline.run(HttpRequest("GET", "http://localhost"))
163151
with pytest.raises(ServiceRequestError):
164-
await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True)
152+
await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True)
165153

166154
# when enforce_https=False, an insecure request should pass
167-
await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)
155+
await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False)
168156

169157
# https requests should always pass
170-
await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False)
171-
await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True)
172-
await pipeline.run(HttpRequest("GET", "https://secure"))
158+
await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False)
159+
await pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True)
160+
await pipeline.run(HttpRequest("GET", "https://localhost"))
173161

174162

175163
async def test_preserves_enforce_https_opt_out():
@@ -186,10 +174,10 @@ async def send(_):
186174

187175
get_token = get_completed_future(AccessToken("***", 42))
188176
credential = Mock(get_token=lambda *_, **__: get_token)
189-
policies = [MockPolicy(credential, "scope"), ContextValidator()]
177+
policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()]
190178
pipeline = AsyncPipeline(transport=transport, policies=policies)
191179

192-
await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)
180+
await pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False)
193181

194182

195183
async def test_context_unmodified_by_default():
@@ -206,10 +194,10 @@ async def send(_):
206194

207195
get_token = get_completed_future(AccessToken("***", 42))
208196
credential = Mock(get_token=lambda *_, **__: get_token)
209-
policies = [MockPolicy(credential, "scope"), ContextValidator()]
197+
policies = [AsyncChallengeAuthenticationPolicy(credential, "scope"), ContextValidator()]
210198
pipeline = AsyncPipeline(transport=transport, policies=policies)
211199

212-
await pipeline.run(HttpRequest("GET", "https://secure"))
200+
await pipeline.run(HttpRequest("GET", "https://localhost"))
213201

214202

215203
async def test_cannot_complete_challenge():
@@ -225,12 +213,13 @@ async def send(_):
225213
expected_scope = "scope"
226214
get_token = Mock(return_value=get_completed_future(AccessToken("***", 42)))
227215
credential = Mock(get_token=get_token)
228-
policy = MockPolicy(credential, expected_scope)
216+
policy = AsyncChallengeAuthenticationPolicy(credential, expected_scope)
217+
policy.on_challenge = Mock(wraps=policy.on_challenge)
229218

230219
pipeline = AsyncPipeline(transport=transport, policies=[policy])
231220
response = await pipeline.run(HttpRequest("GET", "https://localhost"))
232221

233-
assert policy.on_challenge_called
222+
assert policy.on_challenge.called
234223
assert response.http_response is expected_response
235224
assert transport.send.call_count == 1
236225
credential.get_token.assert_called_once_with(expected_scope)

sdk/core/azure-core/tests/test_challenge_authentication.py

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
# THE SOFTWARE.
2424
#
2525
# --------------------------------------------------------------------------
26-
import base64
27-
import itertools
2826
import time
2927

3028
from azure.core.credentials import AccessToken
@@ -42,16 +40,6 @@
4240
from mock import Mock
4341

4442

45-
class MockPolicy(ChallengeAuthenticationPolicy):
46-
def __init__(self, *args, **kwargs):
47-
super(MockPolicy, self).__init__(*args, **kwargs)
48-
self.on_challenge_called = False
49-
50-
def on_challenge(self, request, response, challenge):
51-
self.on_challenge_called = True
52-
return False
53-
54-
5543
def test_adds_header():
5644
"""The policy should add a header containing a token from its credential"""
5745
# 2524608000 == 01/01/2050 @ 12:00am (UTC)
@@ -62,15 +50,15 @@ def verify_authorization_header(request):
6250
return Mock()
6351

6452
fake_credential = Mock(get_token=Mock(return_value=expected_token))
65-
policy = MockPolicy(fake_credential, "scope")
53+
policy = ChallengeAuthenticationPolicy(fake_credential, "scope")
6654
policies = [policy, Mock(send=verify_authorization_header)]
6755

6856
pipeline = Pipeline(transport=Mock(), policies=policies)
69-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
57+
pipeline.run(HttpRequest("GET", "https://localhost"))
7058

7159
assert fake_credential.get_token.call_count == 1
7260

73-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
61+
pipeline.run(HttpRequest("GET", "https://localhost"))
7462

7563
# Didn't need a new token
7664
assert fake_credential.get_token.call_count == 1
@@ -81,7 +69,7 @@ def test_default_context():
8169
expected_scope = "scope"
8270
token = AccessToken("", 0)
8371
credential = Mock(get_token=Mock(return_value=token))
84-
policy = MockPolicy(credential, expected_scope)
72+
policy = ChallengeAuthenticationPolicy(credential, expected_scope)
8573
pipeline = Pipeline(transport=Mock(), policies=[policy])
8674

8775
pipeline.run(HttpRequest("GET", "https://localhost"))
@@ -91,16 +79,16 @@ def test_default_context():
9179

9280
def test_send():
9381
"""The policy should invoke the next policy's send method and return the result"""
94-
expected_request = HttpRequest("GET", "https://spam.eggs")
82+
expected_request = HttpRequest("GET", "https://localhost")
9583
expected_response = Mock()
9684

9785
def verify_request(request):
9886
assert request.http_request is expected_request
9987
return expected_response
10088

10189
fake_credential = Mock(get_token=lambda _: AccessToken("", 0))
102-
policy = MockPolicy(fake_credential, "scope")
103-
policies = [MockPolicy(fake_credential, "scope"), Mock(send=verify_request)]
90+
policy = ChallengeAuthenticationPolicy(fake_credential, "scope")
91+
policies = [policy, Mock(send=verify_request)]
10492
response = Pipeline(transport=Mock(), policies=policies).run(expected_request)
10593

10694
assert response is expected_response
@@ -109,24 +97,24 @@ def verify_request(request):
10997
def test_token_caching():
11098
good_for_one_hour = AccessToken("token", time.time() + 3600)
11199
credential = Mock(get_token=Mock(return_value=good_for_one_hour))
112-
policy = MockPolicy(credential, "scope")
100+
policy = ChallengeAuthenticationPolicy(credential, "scope")
113101
pipeline = Pipeline(transport=Mock(), policies=[policy])
114102

115-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
103+
pipeline.run(HttpRequest("GET", "https://localhost"))
116104
assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token
117105

118-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
106+
pipeline.run(HttpRequest("GET", "https://localhost"))
119107
assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache
120108

121109
expired_token = AccessToken("token", time.time())
122110
credential.get_token.reset_mock()
123111
credential.get_token.return_value = expired_token
124-
pipeline = Pipeline(transport=Mock(), policies=[MockPolicy(credential, "scope")])
112+
pipeline = Pipeline(transport=Mock(), policies=[ChallengeAuthenticationPolicy(credential, "scope")])
125113

126-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
114+
pipeline.run(HttpRequest("GET", "https://localhost"))
127115
assert credential.get_token.call_count == 1
128116

129-
pipeline.run(HttpRequest("GET", "https://spam.eggs"))
117+
pipeline.run(HttpRequest("GET", "https://localhost"))
130118
assert credential.get_token.call_count == 2 # token expired -> policy should call get_token
131119

132120

@@ -138,22 +126,22 @@ def assert_option_popped(request, **kwargs):
138126
return Mock()
139127

140128
credential = Mock(get_token=lambda *_, **__: AccessToken("***", 42))
141-
policy = MockPolicy(credential, "scope")
129+
policy = ChallengeAuthenticationPolicy(credential, "scope")
142130
pipeline = Pipeline(transport=Mock(send=assert_option_popped), policies=[policy])
143131

144132
# by default and when enforce_https=True, the policy should raise when given an insecure request
145133
with pytest.raises(ServiceRequestError):
146-
pipeline.run(HttpRequest("GET", "http://not.secure"))
134+
pipeline.run(HttpRequest("GET", "http://localhost"))
147135
with pytest.raises(ServiceRequestError):
148-
pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True)
136+
pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=True)
149137

150138
# when enforce_https=False, an insecure request should pass
151-
pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)
139+
pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False)
152140

153141
# https requests should always pass
154-
pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False)
155-
pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True)
156-
pipeline.run(HttpRequest("GET", "https://secure"))
142+
pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=False)
143+
pipeline.run(HttpRequest("GET", "https://localhost"), enforce_https=True)
144+
pipeline.run(HttpRequest("GET", "https://localhost"))
157145

158146

159147
def test_preserves_enforce_https_opt_out():
@@ -165,10 +153,10 @@ def on_request(self, request):
165153
return Mock()
166154

167155
credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))
168-
policy = MockPolicy(credential, "scope")
156+
policy = ChallengeAuthenticationPolicy(credential, "scope")
169157
pipeline = Pipeline(transport=Mock(), policies=[policy])
170158

171-
pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False)
159+
pipeline.run(HttpRequest("GET", "http://localhost"), enforce_https=False)
172160

173161

174162
def test_context_unmodified_by_default():
@@ -179,11 +167,11 @@ def on_request(self, request):
179167
assert not any(request.context), "the policy shouldn't add to the request's context"
180168

181169
credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))
182-
policy = MockPolicy(credential, "scope")
170+
policy = ChallengeAuthenticationPolicy(credential, "scope")
183171
policies = [policy, ContextValidator()]
184172
pipeline = Pipeline(transport=Mock(), policies=policies)
185173

186-
pipeline.run(HttpRequest("GET", "https://secure"))
174+
pipeline.run(HttpRequest("GET", "https://localhost"))
187175

188176

189177
def test_cannot_complete_challenge():
@@ -194,12 +182,13 @@ def test_cannot_complete_challenge():
194182
credential = Mock(get_token=Mock(return_value=expected_token))
195183
expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'})
196184
transport = Mock(send=Mock(return_value=expected_response))
197-
policy = MockPolicy(credential, expected_scope)
185+
policy = ChallengeAuthenticationPolicy(credential, expected_scope)
186+
policy.on_challenge = Mock(wraps=policy.on_challenge)
198187

199188
pipeline = Pipeline(transport=transport, policies=[policy])
200189
response = pipeline.run(HttpRequest("GET", "https://localhost"))
201190

202-
assert policy.on_challenge_called
191+
assert policy.on_challenge.called
203192
assert response.http_response is expected_response
204193
assert transport.send.call_count == 1
205194
credential.get_token.assert_called_once_with(expected_scope)

0 commit comments

Comments
 (0)