diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 7277d5f47c5..c9f94aa3fbe 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -67,6 +67,41 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False +def check_port_available(port: int) -> int: + import socket + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('localhost', port)) + return port + except socket.error: + # find a free port + sock = socket.socket() + sock.bind(('', 0)) + return sock.getsockname()[1] + + +def revise_disaggregated_server_config_urls_with_free_ports( + disaggregated_server_config: Dict[str, Any]) -> Dict[str, Any]: + disaggregated_server_config['port'] = check_port_available( + disaggregated_server_config['port']) + ctx_urls = disaggregated_server_config["context_servers"]["urls"] + gen_urls = disaggregated_server_config["generation_servers"]["urls"] + + new_ctx_urls = [] + new_gen_urls = [] + for url in ctx_urls: + port = check_port_available(int(url.split(":")[1])) + new_ctx_urls.append(f"localhost:{port}") + for url in gen_urls: + port = check_port_available(int(url.split(":")[1])) + new_gen_urls.append(f"localhost:{port}") + + disaggregated_server_config["context_servers"]["urls"] = new_ctx_urls + disaggregated_server_config["generation_servers"]["urls"] = new_gen_urls + + return disaggregated_server_config + + @contextlib.contextmanager def launch_disaggregated_llm( disaggregated_server_config: Dict[str, Any], @@ -87,6 +122,9 @@ def launch_disaggregated_llm( f"Using unified tp parameter for testing is not recommended. Please use server configs instead." ) + disaggregated_server_config = revise_disaggregated_server_config_urls_with_free_ports( + disaggregated_server_config) + with open(disaggregated_serving_config_path, "w") as f: yaml.dump(disaggregated_server_config, f) ctx_server_config_path = os.path.join(temp_dir.name, @@ -138,6 +176,7 @@ def launch_disaggregated_llm( ctx_urls = disaggregated_server_config["context_servers"]["urls"] gen_urls = disaggregated_server_config["generation_servers"]["urls"] + serve_port = disaggregated_server_config["port"] ctx_ports = [int(url.split(":")[1]) for url in ctx_urls] gen_ports = [int(url.split(":")[1]) for url in gen_urls] @@ -236,14 +275,14 @@ def multi_popen(server_configs, server_name="", enable_redirect_log=False): ) try: print("Checking health endpoint") - response = requests.get("http://localhost:8000/health") + response = requests.get(f"http://localhost:{serve_port}/health") if response.status_code == 200: break except requests.exceptions.ConnectionError: continue client = openai.OpenAI(api_key="1234567890", - base_url=f"http://localhost:8000/v1", + base_url=f"http://localhost:{serve_port}/v1", timeout=1800000) def send_request(prompt: str, sampling_params: SamplingParams,