This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Copy pathweb_channel.py
53 lines (39 loc) · 1.57 KB
/
web_channel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import asyncio
import websockets
from .base_channel import BaseChannel
from .log_utils import LogType, nni_log
class WebChannel(BaseChannel):
    """Channel that exchanges messages over a WebSocket connection.

    The connection targets ``ws://<args.nnimanager_ip>:<args.nnimanager_port>``.
    The ``websockets`` client API is coroutine based, so every operation is
    driven to completion on an asyncio event loop from synchronous code.
    """

    def __init__(self, args):
        """Store the connection settings; no connection is opened here.

        Args:
            args: settings object; this class reads ``node_id``,
                ``nnimanager_ip`` and ``nnimanager_port`` from it.
        """
        self.node_id = args.node_id
        self.args = args
        # Connected websocket client, or None while the channel is closed.
        self.client = None
        # Received bytes that have not yet been split into complete messages.
        self.in_cache = b""

        super(WebChannel, self).__init__(args)

        # Event loop used for connecting and receiving; set by _inner_open().
        self._event_loop = None

    def _inner_open(self):
        """Connect to the WebSocket endpoint and keep the client and its loop."""
        url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port)
        # NOTE: this is logged before the connection attempt is made.
        nni_log(LogType.Info, 'WebChannel: connected with info %s' % url)

        connect = websockets.connect(url)
        self._event_loop = asyncio.get_event_loop()
        self.client = self._event_loop.run_until_complete(connect)

    def _inner_close(self):
        """Close the WebSocket connection, if one is open, and reset state.

        ``client.close()`` returns a coroutine, so it has to be executed on an
        event loop to have any effect. When the loop is idle it is run to
        completion here. When the loop is already running (presumably inside
        ``_inner_receive`` on another thread -- confirm against the caller),
        it cannot be driven or closed from here, so the close coroutine is
        handed to it thread-safely and not waited for.

        The loop itself is not closed: it was obtained from
        ``asyncio.get_event_loop()``, so it is not owned by this channel, and
        a closed loop would make a later ``_inner_open`` on it fail.
        """
        if self.client is None:
            return

        client, loop = self.client, self._event_loop
        try:
            if loop.is_running():
                asyncio.run_coroutine_threadsafe(client.close(), loop)
            else:
                loop.run_until_complete(client.close())
        finally:
            # Always mark the channel as closed, even if shutting down failed.
            self.client = None
            self._event_loop = None

    def _inner_send(self, message):
        """Send one message through the open connection.

        A dedicated short-lived event loop is used instead of
        ``self._event_loop`` (presumably because that loop may be busy in
        ``_inner_receive`` on another thread -- TODO confirm). The temporary
        loop is always closed so its resources are not leaked on every send.

        Args:
            message: payload passed unchanged to ``client.send``.
        """
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.client.send(message))
        finally:
            loop.close()

    def _inner_receive(self):
        """Wait for the next frame and return the complete messages decoded so far.

        Returns:
            The messages produced by ``self._fetch_message`` (a helper not
            defined in this class, presumably inherited from ``BaseChannel``),
            or an empty list when the channel is not open.
        """
        messages = []
        if self.client is not None:
            received = self._event_loop.run_until_complete(self.client.recv())
            # The frame is expected to arrive as text (str); it is encoded so
            # the cache is always bytes.
            # NOTE(review): a binary frame would arrive as bytes and fail
            # here -- confirm only text frames are sent.
            self.in_cache += received.encode("utf8")
            # _fetch_message returns the parsed messages and the unconsumed
            # remainder of the cache.
            messages, self.in_cache = self._fetch_message(self.in_cache)

        return messages