Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,41 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False


def check_port_available(port: int) -> int:
    """Return ``port`` if it can be bound on localhost, otherwise a free port.

    The requested port is probed by binding a TCP socket to
    ``('localhost', port)``. If that bind fails with ``OSError``, the OS is
    asked for an ephemeral port by binding to port ``0`` and that port number
    is returned instead.

    Both probe sockets are closed before returning, so the returned port is
    only known to have been free at the time of the check; another process
    may still claim it before the caller binds to it.

    Args:
        port: The preferred TCP port number.

    Returns:
        ``port`` when it is currently bindable, else an OS-assigned free port.
    """
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(('localhost', port))
    except OSError:  # socket.error is an alias of OSError
        # Preferred port is unavailable: binding to port 0 lets the OS pick a
        # free one. The context manager closes the socket so the descriptor
        # is not leaked and the port is released for the real server.
        with socket.socket() as fallback:
            fallback.bind(('', 0))
            return fallback.getsockname()[1]
    return port


def revise_disaggregated_server_config_urls_with_free_ports(
        disaggregated_server_config: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite the config's port and server URLs using ``check_port_available``.

    The top-level ``"port"`` entry and the port of every entry in
    ``["context_servers"]["urls"]`` and ``["generation_servers"]["urls"]``
    are passed through ``check_port_available``. Each URL is rewritten as
    ``localhost:<port>``; the port is taken from the text after the first
    ``":"`` in the original URL.

    The config dict is modified in place and the same object is returned.
    """

    def _relocate(urls):
        # Only the port part of each URL is kept; the host is always
        # emitted as "localhost".
        return [
            f"localhost:{check_port_available(int(entry.split(':')[1]))}"
            for entry in urls
        ]

    config = disaggregated_server_config
    config['port'] = check_port_available(config['port'])

    # Look up both URL lists before probing any of their ports, and only
    # write the results back once both lists have been fully processed.
    original_ctx = config["context_servers"]["urls"]
    original_gen = config["generation_servers"]["urls"]

    relocated_ctx = _relocate(original_ctx)
    relocated_gen = _relocate(original_gen)

    config["context_servers"]["urls"] = relocated_ctx
    config["generation_servers"]["urls"] = relocated_gen
    return config


@contextlib.contextmanager
def launch_disaggregated_llm(
disaggregated_server_config: Dict[str, Any],
Expand All @@ -87,6 +122,9 @@ def launch_disaggregated_llm(
f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
)

disaggregated_server_config = revise_disaggregated_server_config_urls_with_free_ports(
disaggregated_server_config)

with open(disaggregated_serving_config_path, "w") as f:
yaml.dump(disaggregated_server_config, f)
ctx_server_config_path = os.path.join(temp_dir.name,
Expand Down Expand Up @@ -138,6 +176,7 @@ def launch_disaggregated_llm(
ctx_urls = disaggregated_server_config["context_servers"]["urls"]
gen_urls = disaggregated_server_config["generation_servers"]["urls"]

serve_port = disaggregated_server_config["port"]
ctx_ports = [int(url.split(":")[1]) for url in ctx_urls]
gen_ports = [int(url.split(":")[1]) for url in gen_urls]

Expand Down Expand Up @@ -236,14 +275,14 @@ def multi_popen(server_configs, server_name="", enable_redirect_log=False):
)
try:
print("Checking health endpoint")
response = requests.get("http://localhost:8000/health")
response = requests.get(f"http://localhost:{serve_port}/health")
if response.status_code == 200:
break
except requests.exceptions.ConnectionError:
continue

client = openai.OpenAI(api_key="1234567890",
base_url=f"http://localhost:8000/v1",
base_url=f"http://localhost:{serve_port}/v1",
timeout=1800000)

def send_request(prompt: str, sampling_params: SamplingParams,
Expand Down