Skip to content

Commit

Permalink
Parallel process to send state to Envision server (#567)
Browse files Browse the repository at this point in the history
* Replaced multithreading with multiprocessing and queue.
* Modified test to support multiprocessing Envision client.
  • Loading branch information
Adaickalavan authored Feb 11, 2021
1 parent 1a6703d commit deaab76
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 87 deletions.
10 changes: 6 additions & 4 deletions cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# MIT License

# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
Expand Down
22 changes: 12 additions & 10 deletions cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# MIT License

# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import click

from .envision import envision_cli
from .studio import scenario_cli
from .zoo import zoo_cli
from .ultra import ultra_cli
from cli.envision import envision_cli
from cli.studio import scenario_cli
from cli.ultra import ultra_cli
from cli.zoo import zoo_cli


@click.group()
Expand All @@ -32,9 +35,8 @@ def scl():

scl.add_command(envision_cli)
scl.add_command(scenario_cli)
scl.add_command(zoo_cli)
scl.add_command(ultra_cli)

scl.add_command(zoo_cli)

if __name__ == "__main__":
scl()
123 changes: 59 additions & 64 deletions envision/client.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,42 @@
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# MIT License

# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import json
import logging
import multiprocessing
import numpy as np
import re
import time
import uuid
import json
import logging
import threading
import warnings
import websocket

from datetime import datetime
from pathlib import Path
from queue import Queue
from typing import Union
from pathlib import Path
import warnings

import websocket
import numpy as np

from envision import types
from smarts.core.utils.file import unpack
from . import types


class JSONEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -91,28 +94,30 @@ def __init__(
if endpoint is None:
endpoint = "ws://localhost:8081"

self._logging_thread = None
self._logging_queue = Queue()
self._logging_process = None
if output_dir:
self._logging_thread = self._spawn_logging_thread(output_dir, client_id)
self._logging_thread.start()
output_dir = Path(f"{output_dir}/{int(time.time())}")
output_dir.mkdir(parents=True, exist_ok=True)
path = (output_dir / client_id).with_suffix(".jsonl")
self._logging_queue = multiprocessing.Queue()
self._logging_process = multiprocessing.Process(
target=self._write_log_state, args=(self._logging_queue, path,)
)
self._logging_process.daemon = True
self._logging_process.start()

if not self._headless:
self._state_queue = Queue()
self._thread = self._connect(
endpoint=f"{endpoint}/simulations/{client_id}/broadcast",
queue=self._state_queue,
wait_between_retries=wait_between_retries,
self._state_queue = multiprocessing.Queue()
self._process = multiprocessing.Process(
target=self._connect,
args=(
f"{endpoint}/simulations/{client_id}/broadcast",
self._state_queue,
wait_between_retries,
),
)
self._thread.start()

def _spawn_logging_thread(self, output_dir, client_id):
output_dir = Path(f"{output_dir}/{int(time.time())}")
output_dir.mkdir(parents=True, exist_ok=True)
path = (output_dir / client_id).with_suffix(".jsonl")
return threading.Thread(
target=self._write_log_state, args=(self._logging_queue, path), daemon=True
)
self._process.daemon = True
self._process.start()

@staticmethod
def _write_log_state(queue, path):
Expand Down Expand Up @@ -146,9 +151,9 @@ def read_and_send(
logging.info("Finished Envision data replay")

def _connect(
self, endpoint, queue, wait_between_retries: float = 0.05,
self, endpoint, state_queue, wait_between_retries: float = 0.05,
):
threadlocal = threading.local()
connection_established = False

def optionally_serialize_and_write(state: Union[types.State, str], ws):
# if not already serialized
Expand All @@ -170,58 +175,49 @@ def on_error(ws, error):
self._log.error(f"Connection to Envision terminated with: {error}")

def on_open(ws):
setattr(threadlocal, "connection_established", True)
nonlocal connection_established
connection_established = True

while True:
state = queue.get()
state = state_queue.get()
if type(state) is Client.QueueDone:
ws.close()
break

optionally_serialize_and_write(state, ws)

def run_socket(endpoint, wait_between_retries, threadlocal):
connection_established = False

def run_socket(endpoint, wait_between_retries):
nonlocal connection_established
tries = 1
while True:
ws = websocket.WebSocketApp(
endpoint, on_error=on_error, on_close=on_close, on_open=on_open
)
self._log.info("Connected to Envision")

with warnings.catch_warnings():
# XXX: websocket-client library seems to have leaks on connection
# retry that cause annoying warnings within Python 3.8+
warnings.filterwarnings("ignore", category=ResourceWarning)
ws.run_forever()

connection_established = getattr(
threadlocal, "connection_established", False
)

if not connection_established:
self._log.info(f"Attempting to connect to Envision tries={tries}")
self._log.info(f"Attempt {tries} to connect to Envision.")
else:
# when connection closed, retry again every 5 seconds
wait_between_retries = 5
# When connection lost, retry again every 3 seconds
wait_between_retries = 3
self._log.info(
f"Connection to Envision lost. Attempting to reconnect."
f"Connection to Envision lost. Attempt {tries} to reconnect."
)

tries += 1
time.sleep(wait_between_retries)

return threading.Thread(
target=run_socket,
args=(endpoint, wait_between_retries, threadlocal),
daemon=True, # If False, the proc will not terminate until this thread stops
)
run_socket(endpoint, wait_between_retries)

def send(self, state: types.State):
if not self._headless and self._thread.is_alive():
if not self._headless and self._process.is_alive():
self._state_queue.put(state)
if self._logging_thread:
if self._logging_process:
self._logging_queue.put(state)

def _send_raw(self, state: str):
Expand All @@ -233,13 +229,12 @@ def _send_raw(self, state: str):
def teardown(self):
if not self._headless:
self._state_queue.put(Client.QueueDone())

self._logging_queue.put(Client.QueueDone())

if not self._headless and self._thread:
self._thread.join(timeout=3)
self._thread = None

if self._logging_thread:
self._logging_thread.join(timeout=3)
self._logging_thread = None
self._process.join(timeout=3)
self._process = None
self._state_queue.close()

if self._logging_process:
self._logging_queue.put(Client.QueueDone())
self._logging_process.join(timeout=3)
self._logging_process = None
self._logging_queue.close()
24 changes: 15 additions & 9 deletions envision/tests/test_data_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import tempfile
from pathlib import Path

import multiprocessing
import pytest
import tempfile
import websocket

from pathlib import Path

from envision.client import Client as Envision
from smarts.core.agent import AgentSpec, Agent
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.scenario import Scenario
from smarts.core.smarts import SMARTS
from smarts.core.sumo_traffic_simulation import SumoTrafficSimulation
from smarts.core.agent import AgentSpec, Agent


AGENT_ID = "Agent-007"
Expand Down Expand Up @@ -67,7 +69,7 @@ def scenarios_iterator():

def fake_websocket_app_class():
# Using a closure instead of a class field to give isolation between tests.
sent = []
sent = multiprocessing.Queue()

class FakeWebSocketApp:
"""Mocks out the websockets.WebSocketApp to intercept send(...) calls and just
Expand All @@ -80,7 +82,7 @@ def __init__(self, endpoint, on_error, on_close, on_open):
self._on_open = on_open

def send(self, data):
sent.append(data)
sent.put(data)
return len(data)

def run_forever(self):
Expand Down Expand Up @@ -117,7 +119,7 @@ def step_through_episodes(agent_spec, smarts, scenarios_iterator):
# Mock WebSocketApp so we can inspect the websocket frames being sent
FakeWebSocketApp, original_sent_data = fake_websocket_app_class()
monkeypatch.setattr(websocket, "WebSocketApp", FakeWebSocketApp)
assert len(original_sent_data) == 0
assert original_sent_data.qsize() == 0

envision = Envision(output_dir=data_replay_path)
smarts = SMARTS(
Expand All @@ -135,15 +137,19 @@ def step_through_episodes(agent_spec, smarts, scenarios_iterator):

jsonl_paths = list(data_replay_run_paths[0].glob("*.jsonl"))
assert len(jsonl_paths) == 1
assert len(original_sent_data) > 0
assert original_sent_data.qsize() > 0

# 2. Inspect replay data

# Mock WebSocketApp so we can inspect the websocket frames being sent
FakeWebSocketApp, new_sent_data = fake_websocket_app_class()
monkeypatch.setattr(websocket, "WebSocketApp", FakeWebSocketApp)
assert len(new_sent_data) == 0
assert new_sent_data.qsize() == 0

# Now read data replay
Envision.read_and_send(jsonl_paths[0], timestep_sec=TIMESTEP_SEC)
assert original_sent_data == new_sent_data

# Verify the new data matches the original data
assert original_sent_data.qsize() == new_sent_data.qsize()
for _ in range(new_sent_data.qsize()):
assert original_sent_data.get() == new_sent_data.get()

0 comments on commit deaab76

Please sign in to comment.