forked from envoyproxy/envoy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhotrestart_handoff_test.py
505 lines (441 loc) · 18.3 KB
/
hotrestart_handoff_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
"""Tests the behavior of connection handoff between instances during hot restart.
Specifically, tests that:
1. TCP connections opened before hot restart begins continue to function during drain.
2. TCP connections opened after hot restart begins while the old instance is still running
go to the new instance.
TODO(ravenblack): perform the same tests for QUIC connections once they work as expected.
"""
import abc
import argparse
import asyncio
import logging
import os
import pathlib
import random
import sys
import tempfile
import unittest
from datetime import datetime, timedelta, timezone
from functools import cached_property
from typing import Awaitable

from aiohttp import client_exceptions, web, ClientSession
def random_loopback_host():
    """Return a randomized 127.x.y.z loopback IP address.

    Randomizing the loopback host reduces the chance of port conflicts
    when multiple test invocations run in parallel on one machine."""
    second_octet = random.randrange(0, 256)
    third_octet = random.randrange(0, 256)
    # Last octet avoids 0 (network-style) and 255 (broadcast-style).
    fourth_octet = random.randrange(1, 255)
    return f"127.{second_octet}.{third_octet}.{fourth_octet}"
# This is a timeout that must be long enough that the hot restarted
# instance can reliably be fully started up within this many seconds, or the
# test will be flaky. 3 seconds is enough on a not-busy host with a non-tsan
# non-coverage build; 10 seconds should be enough to be not flaky in most
# configurations.
#
# Unfortunately, because the test is verifying the behavior of a connection
# during drain, the first connection must last for the full tolerance duration,
# so increasing this value increases the duration of the test. For this
# reason we want to keep it as low as possible without causing flaky failure.
#
# Ideally this would be adjusted (3x) for tsan and coverage runs, but making that
# possible for python is outside the scope of this test, so we're stuck using the
# 3x value for all tests.
STARTUP_TOLERANCE_SECONDS = 10
# We send multiple requests in parallel and require them all to function correctly
# - this makes it so if something is flaky we're more likely to encounter it, and
# also tests that there's not an "only one" success situation.
PARALLEL_REQUESTS = 5
# Fixed ports for the slow/fast upstreams, Envoy's listener, and Envoy's
# admin endpoint; the host is randomized to reduce collisions between
# parallel test runs.
UPSTREAM_SLOW_PORT = 54321
UPSTREAM_FAST_PORT = 54322
UPSTREAM_HOST = random_loopback_host()
ENVOY_HOST = UPSTREAM_HOST
ENVOY_PORT = 54323
ENVOY_ADMIN_PORT = 54324
# Append process ID to the socket path to minimize chances of
# conflict. We can't use TEST_TMPDIR for this because it makes
# the socket path too long.
SOCKET_PATH = f"@envoy_domain_socket_{os.getpid()}"
SOCKET_MODE = 0
# This log config makes logs interleave with other test output, which
# is useful since with all the async operations it can be hard to figure
# out what's happening.
log = logging.getLogger()
log.level = logging.INFO
_stream_handler = logging.StreamHandler(sys.stdout)
log.addHandler(_stream_handler)
class Upstream:
    """An aiohttp server standing in for Envoy's upstream cluster.

    The default ("slow") variant answers GET / by sending "start",
    pausing for longer than the startup tolerance, then sending "end" —
    which lets the test observe an already-open connection surviving a
    hot restart. The fast variant replies immediately with
    "fast instance".
    """

    def __init__(self, fast_version=False):
        if fast_version:
            self.port = UPSTREAM_FAST_PORT
            handler = self.fast_response
        else:
            self.port = UPSTREAM_SLOW_PORT
            handler = self.slow_response
        self.app = web.Application()
        self.app.add_routes([web.get("/", handler)])

    async def start(self):
        # handle_signals=False: signal handling belongs to the test runner.
        self.runner = web.AppRunner(self.app, handle_signals=False)
        await self.runner.setup()
        tcp_site = web.TCPSite(self.runner, host=UPSTREAM_HOST, port=self.port)
        await tcp_site.start()

    async def stop(self):
        await self.runner.shutdown()
        await self.runner.cleanup()
        log.debug("runner cleaned up")

    async def fast_response(self, request):
        return web.Response(
            body="fast instance",
            status=200,
            reason="OK",
            headers={"content-type": "text/plain"},
        )

    async def slow_response(self, request):
        log.debug("slow request received")
        resp = web.StreamResponse(
            status=200, reason="OK", headers={"content-type": "text/plain"})
        await resp.prepare(request)
        await resp.write(b"start\n")
        # Stay open past the hot-restart window so drain behavior is observable.
        await asyncio.sleep(STARTUP_TOLERANCE_SECONDS + 0.5)
        await resp.write(b"end\n")
        await resp.write_eof()
        return resp
class LineGenerator:
    """Base class that runs a background task feeding lines into a queue.

    Subclasses implement generator() to produce response lines and put
    them on self._queue; consumers pull them one at a time with line().
    """

    @cached_property
    def _queue(self) -> asyncio.Queue[bytes]:
        # Created lazily so the queue binds to the running event loop.
        # Holds bytes: producers enqueue raw lines such as b"start\n".
        return asyncio.Queue()

    @cached_property
    def _task(self):
        # First access starts the producer coroutine as a task.
        return asyncio.create_task(self.generator())

    @abc.abstractmethod
    async def generator(self) -> None:
        """Produce lines and put them on self._queue; runs as a task."""
        raise NotImplementedError

    def __init__(self):
        # Touch the cached property to kick off the producer immediately.
        self._task

    async def join(self) -> int:
        """Wait for the producer to finish; return count of unconsumed lines."""
        await self._task
        return self._queue.qsize()

    async def line(self) -> bytes:
        """Return the next produced line (bytes), blocking until available."""
        line = await self._queue.get()
        self._queue.task_done()
        return line
class Http3RequestLineGenerator(LineGenerator):
    """Streams response lines from the external HTTP/3 client subprocess."""

    def __init__(self, url):
        self._url = url
        super().__init__()

    async def generator(self) -> None:
        # The h3 client binary and CA bundle paths are injected onto
        # IntegrationTest by main() before tests run.
        proc = await asyncio.create_subprocess_exec(
            IntegrationTest.h3_request,
            f"--ca-certs={IntegrationTest.ca_certs}",
            self._url,
            stdout=asyncio.subprocess.PIPE,
        )
        # Forward each stdout line (bytes) to the queue as it arrives.
        async for line in proc.stdout:
            await self._queue.put(line)
        await proc.wait()
class HttpRequestLineGenerator(LineGenerator):
    """Streams response lines from a plain HTTP request made with aiohttp."""

    def __init__(self, url):
        self._url = url
        super().__init__()

    async def generator(self) -> None:
        # Separate session per request is against aiohttp idioms, but is
        # intentional here because the point of the test is verifying
        # where connections go - reusing a connection would do the wrong thing.
        async with ClientSession() as session:
            async with session.get(self._url) as response:
                # Each line arrives as bytes and is forwarded unmodified.
                async for line in response.content:
                    await self._queue.put(line)
async def _full_http3_request_task(url: str) -> str:
    """Run the external HTTP/3 client against `url` and return its full stdout.

    Uses the h3 client binary and CA bundle injected onto IntegrationTest
    by main().
    """
    proc = await asyncio.create_subprocess_exec(
        IntegrationTest.h3_request,
        f"--ca-certs={IntegrationTest.ca_certs}",
        url,
        stdout=asyncio.subprocess.PIPE,
    )
    # communicate() reads stdout to EOF and waits for the process to exit,
    # so the extra proc.wait() the original called afterwards was redundant.
    (stdout, _) = await proc.communicate()
    return stdout.decode("utf-8")
def _full_http3_request(url: str) -> Awaitable[str]:
    """Kick off an HTTP/3 request in the background; await the result later."""
    request_coro = _full_http3_request_task(url)
    return asyncio.create_task(request_coro)
async def _full_http_request_task(url: str) -> str:
    """Fetch `url` over HTTP and return the complete response body as text.

    A fresh ClientSession per request is deliberately anti-idiomatic for
    aiohttp: the test verifies where *new connections* land, so reusing a
    pooled connection would do the wrong thing.
    """
    async with ClientSession() as session, session.get(url) as response:
        return await response.text()
def _full_http_request(url: str) -> Awaitable[str]:
    """Kick off an HTTP request in the background; await the result later."""
    request_coro = _full_http_request_task(url)
    return asyncio.create_task(request_coro)
def filter_chains(codec_type: str = "AUTO") -> str:
    """Return a YAML fragment configuring a listener's HTTP connection manager.

    codec_type selects the downstream codec ("AUTO" for the TCP listener,
    "HTTP3" for the QUIC listener). All paths route to the `some_service`
    cluster.

    NOTE(review): the original indentation of this YAML literal was lost
    in extraction; reconstructed to fit the listener nesting in
    _make_envoy_config_yaml — verify against the upstream source.
    """
    return f"""
    filter_chains:
    - filters:
      - name: envoy.filters.network.http_connection_manager
        typed_config:
          "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
          stat_prefix: ingress_http
          codec_type: {codec_type}
          route_config:
            name: local_route
            virtual_hosts:
            - name: local_service
              domains: ["*"]
              routes:
              - match:
                  prefix: "/"
                route:
                  cluster: some_service
          http_filters:
          - name: envoy.filters.http.router
            typed_config:
              "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
"""
def _make_envoy_config_yaml(upstream_port: int, file_path: pathlib.Path):
    """Write an Envoy bootstrap config to file_path.

    The config exposes an admin endpoint, a QUIC (UDP/HTTP3) listener and
    a TCP listener sharing ENVOY_PORT, and one static cluster pointing at
    upstream_port. Called once for the slow upstream's port and once for
    the fast upstream's port.

    NOTE(review): the original indentation of this YAML literal was lost
    in extraction; reconstructed — verify against the upstream source.
    """
    file_path.write_text(
        f"""
admin:
  address:
    socket_address:
      address: {ENVOY_HOST}
      port_value: {ENVOY_ADMIN_PORT}
static_resources:
  listeners:
  - name: listener_quic
    address:
      socket_address:
        protocol: UDP
        address: {ENVOY_HOST}
        port_value: {ENVOY_PORT}
{filter_chains("HTTP3")}
      transport_socket:
        name: "envoy.transport_sockets.quic"
        typed_config:
          "@type": "type.googleapis.com/envoy.extensions.transport_sockets.quic.v3.QuicDownstreamTransport"
          downstream_tls_context:
            common_tls_context:
              tls_certificates:
              - certificate_chain:
                  filename: "{IntegrationTest.server_cert}"
                private_key:
                  filename: "{IntegrationTest.server_key}"
    udp_listener_config:
      quic_options: {"{}"}
      downstream_socket_config:
        prefer_gro: true
  - name: listener_http
    address:
      socket_address:
        address: {ENVOY_HOST}
        port_value: {ENVOY_PORT}
{filter_chains()}
  clusters:
  - name: some_service
    connect_timeout: 0.25s
    type: STATIC
    lb_policy: ROUND_ROBIN
    load_assignment:
      cluster_name: some_service
      endpoints:
      - lb_endpoints:
        - endpoint:
            address:
              socket_address:
                address: {UPSTREAM_HOST}
                port_value: {upstream_port}
""")
async def _wait_for_envoy_epoch(i: int):
    """Load the admin/server_info page until restart_epoch is i, or timeout"""
    expected_substring = f'"restart_epoch": {i}'
    give_up_at = datetime.now() + timedelta(seconds=STARTUP_TOLERANCE_SECONDS)
    response = "admin port not responding within timeout"
    while datetime.now() < give_up_at:
        try:
            response = await _full_http_request(
                f"http://{ENVOY_HOST}:{ENVOY_ADMIN_PORT}/server_info")
        except client_exceptions.ClientConnectorError:
            # Admin listener not up yet; retry after a short pause.
            await asyncio.sleep(0.2)
            continue
        if expected_substring in response:
            return
        await asyncio.sleep(0.2)
    # Envoy instance with expected restart_epoch should have started up
    assert expected_substring in response, f"server_info={response}"
class IntegrationTest(unittest.IsolatedAsyncioTestCase):
    """Exercises TCP and QUIC connection handoff across an Envoy hot restart.

    Starts a "slow" Envoy (epoch 0) routing to the slow upstream, opens
    long-lived requests, then hot-restarts into a "fast" Envoy (epoch 1)
    routing to the fast upstream, verifying that new requests reach the
    new instance while the in-flight requests drain on the old one.
    """

    # Populated by main() from command-line flags before unittest runs.
    server_cert: pathlib.Path
    server_key: pathlib.Path
    ca_certs: pathlib.Path
    h3_request: pathlib.Path
    envoy_binary: pathlib.Path

    async def asyncSetUp(self) -> None:
        # Dump the environment to aid debugging of env-dependent failures.
        print(os.environ)
        tmpdir = os.environ["TEST_TMPDIR"]
        self.slow_config_path = pathlib.Path(tmpdir, "slow_config.yaml")
        self.fast_config_path = pathlib.Path(tmpdir, "fast_config.yaml")
        self.base_id_path = pathlib.Path(tmpdir, "base_id.txt")
        _make_envoy_config_yaml(upstream_port=UPSTREAM_SLOW_PORT, file_path=self.slow_config_path)
        _make_envoy_config_yaml(upstream_port=UPSTREAM_FAST_PORT, file_path=self.fast_config_path)
        # Args shared by both Envoy invocations; the domain socket is how
        # the two instances coordinate the hot restart.
        self.base_envoy_args = [
            IntegrationTest.envoy_binary,
            "--socket-path",
            SOCKET_PATH,
            "--socket-mode",
            str(SOCKET_MODE),
        ]
        log.info(f"starting upstreams on https://{ENVOY_HOST}:{ENVOY_PORT}/")
        await super().asyncSetUp()
        self.slow_upstream = Upstream()
        await self.slow_upstream.start()
        self.fast_upstream = Upstream(True)
        await self.fast_upstream.start()

    async def asyncTearDown(self) -> None:
        await self.slow_upstream.stop()
        await self.fast_upstream.stop()
        return await super().asyncTearDown()

    async def test_connection_handoffs(self) -> None:
        log.info("starting envoy")
        envoy_process_1 = await asyncio.create_subprocess_exec(
            *self.base_envoy_args,
            "--restart-epoch",
            "0",
            "--use-dynamic-base-id",
            "--base-id-path",
            self.base_id_path,
            "-c",
            self.slow_config_path,
        )
        log.info(f"cert path = {IntegrationTest.server_cert}")
        log.info("waiting for envoy ready")
        await _wait_for_envoy_epoch(0)
        log.info("making requests")
        request_url = f"http://{ENVOY_HOST}:{ENVOY_PORT}/"
        srequest_url = f"https://{ENVOY_HOST}:{ENVOY_PORT}/"
        # Open long-lived requests (HTTP and HTTP/3) against the slow
        # upstream; these must survive the hot restart below.
        slow_responses = [
            HttpRequestLineGenerator(request_url) for i in range(PARALLEL_REQUESTS)
        ] + [Http3RequestLineGenerator(srequest_url) for i in range(PARALLEL_REQUESTS)]
        log.info("waiting for responses to begin")
        for response in slow_responses:
            self.assertEqual(await response.line(), b"start\n")
        base_id = int(self.base_id_path.read_text())
        log.info(f"starting envoy hot restart for base id {base_id}")
        # Epoch 1 must reuse the dynamic base id written by epoch 0.
        envoy_process_2 = await asyncio.create_subprocess_exec(
            *self.base_envoy_args,
            "--restart-epoch",
            "1",
            "--parent-shutdown-time-s",
            str(STARTUP_TOLERANCE_SECONDS + 1),
            "--base-id",
            str(base_id),
            "-c",
            self.fast_config_path,
        )
        log.info("waiting for new envoy instance to begin")
        await _wait_for_envoy_epoch(1)
        log.info("sending request to fast upstream")
        fast_responses = [_full_http_request(request_url) for i in range(PARALLEL_REQUESTS)
                         ] + [_full_http3_request(srequest_url) for i in range(PARALLEL_REQUESTS)]
        for response in fast_responses:
            self.assertEqual(
                await response,
                "fast instance",
                "new requests after hot restart begins should go to new cluster",
            )
        # Now wait for the slow request to complete, and make sure it still gets the
        # response from the old instance.
        log.info("waiting for completion of original slow request")
        t1 = datetime.now()
        for response in slow_responses:
            self.assertEqual(await response.line(), b"end\n")
        t2 = datetime.now()
        self.assertGreater(
            (t2 - t1).total_seconds(),
            0.5,
            "slow request should be incomplete when the test waits for it, otherwise the test is not necessarily validating during-drain behavior",
        )
        for response in slow_responses:
            # assertEqual, not assertEquals: the alias was deprecated and
            # removed in Python 3.12.
            self.assertEqual(await response.join(), 0)
        log.info("waiting for parent instance to terminate")
        await envoy_process_1.wait()
        log.info("sending second request to fast upstream")
        fast_responses = [_full_http_request(request_url) for i in range(PARALLEL_REQUESTS)
                         ] + [_full_http3_request(srequest_url) for i in range(PARALLEL_REQUESTS)]
        for response in fast_responses:
            self.assertEqual(
                await response,
                "fast instance",
                "new requests after old instance terminates should go to new cluster",
            )
        log.info("shutting child instance down")
        envoy_process_2.terminate()
        await envoy_process_2.wait()
def generate_server_cert(
        ca_key_path: pathlib.Path,
        ca_cert_path: pathlib.Path) -> "tuple[pathlib.Path, pathlib.Path]":
    """Generates a temporary key and cert pem file and returns the paths.

    This is necessary because the http3 client validates that the server
    certificate matches the host of the request, and our host is an
    arbitrary randomized 127.x.y.z IP address to reduce the likelihood
    of port collisions during testing. We therefore must use a generated
    certificate that really matches the host IP.

    Returns:
        (key_path, cert_path) as pathlib.Path objects.
    """
    from cryptography import x509
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from ipaddress import ip_address
    with open(ca_key_path, "rb") as ca_key_file:
        ca_key = serialization.load_pem_private_key(
            ca_key_file.read(),
            password=None,
        )
    with open(ca_cert_path, "rb") as ca_cert_file:
        ca_cert = x509.load_pem_x509_certificate(ca_cert_file.read())
    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )
    hostname = "testhost"
    name = x509.Name([x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, hostname)])
    alt_names = [x509.DNSName(hostname)]
    # The SAN must include the randomized loopback IP Envoy listens on.
    alt_names.append(x509.IPAddress(ip_address(ENVOY_HOST)))
    san = x509.SubjectAlternativeName(alt_names)
    basic_constraints = x509.BasicConstraints(ca=True, path_length=0)
    # Timezone-aware timestamp: datetime.utcnow() is deprecated since
    # Python 3.12, and cryptography accepts aware datetimes.
    now = datetime.now(timezone.utc)
    cert = (
        x509.CertificateBuilder()  # Comment to keep linter from uglifying!
        .subject_name(name).issuer_name(ca_cert.subject).public_key(key.public_key()).serial_number(
            1).not_valid_before(now).not_valid_after(now + timedelta(days=30)).add_extension(
                basic_constraints,
                False).add_extension(san, False).sign(ca_key, hashes.SHA256(), default_backend()))
    cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM)
    key_pem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    # Fix: the original swapped these suffixes (the cert file was named
    # *_key.pem and vice versa), which made debugging misleading.
    cert_file = tempfile.NamedTemporaryFile(
        suffix="_cert.pem", delete=False, dir=os.environ["TEST_TMPDIR"])
    cert_file.write(cert_pem)
    cert_file.close()
    key_file = tempfile.NamedTemporaryFile(
        suffix="_key.pem", delete=False, dir=os.environ["TEST_TMPDIR"])
    key_file.write(key_pem)
    key_file.close()
    # Return Path objects to match the annotated return type (the original
    # returned raw strings from NamedTemporaryFile.name).
    return pathlib.Path(key_file.name), pathlib.Path(cert_file.name)
def main():
    """Parse test-specific flags, configure IntegrationTest, run unittest."""
    parser = argparse.ArgumentParser(description="Hot restart handoff test")
    for flag in ("--envoy-binary", "--h3-request", "--ca-certs", "--ca-key"):
        parser.add_argument(flag, type=str, required=True)
    # unittest also parses some args, so we strip out the ones we're using
    # and leave the rest for unittest to consume.
    args, remaining = parser.parse_known_args()
    sys.argv[1:] = remaining
    key_path, cert_path = generate_server_cert(args.ca_key, args.ca_certs)
    IntegrationTest.server_key = key_path
    IntegrationTest.server_cert = cert_path
    IntegrationTest.ca_certs = args.ca_certs
    IntegrationTest.h3_request = args.h3_request
    IntegrationTest.envoy_binary = args.envoy_binary
    unittest.main()
# Script entry point: parse flags in main(), then hand off to unittest.
if __name__ == "__main__":
    main()