|
14 | 14 | # limitations under the License. |
15 | 15 | """ |
16 | 16 |
|
17 | | -import os |
18 | | -import threading |
19 | | -import time |
20 | | -import traceback |
| 17 | +from abc import ABC, abstractmethod |
21 | 18 |
|
22 | | -import msgpack |
23 | 19 | import zmq |
24 | 20 |
|
25 | | -from fastdeploy import envs |
26 | | -from fastdeploy.utils import zmq_client_logger |
| 21 | +from fastdeploy.utils import llm_logger |
27 | 22 |
|
28 | 23 |
|
29 | | -class ZmqClient: |
| 24 | +class ZmqClientBase(ABC): |
30 | 25 | """ |
31 | | - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. |
| 26 | + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. |
32 | 27 | """ |
33 | 28 |
|
34 | | - def __init__(self, name, mode): |
35 | | - self.context = zmq.Context(4) |
36 | | - self.socket = self.context.socket(mode) |
37 | | - self.file_name = f"/dev/shm/{name}.socket" |
38 | | - self.router_path = f"/dev/shm/router_{name}.ipc" |
| 29 | + def __init__(self): |
| 30 | + pass |
39 | 31 |
|
40 | | - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) |
41 | | - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND |
| 32 | + @abstractmethod |
| 33 | + def _create_socket(self): |
| 34 | + """Abstract method to create and return a ZeroMQ socket.""" |
| 35 | + pass |
42 | 36 |
|
43 | | - self.mutex = threading.Lock() |
44 | | - self.req_dict = dict() |
45 | | - self.router = None |
46 | | - self.poller = None |
47 | | - self.running = True |
| 37 | + def _ensure_socket(self): |
| 38 | + """Ensure the socket is created before use.""" |
| 39 | + if self.socket is None: |
| 40 | + self.socket = self._create_socket() |
48 | 41 |
|
| 42 | + @abstractmethod |
49 | 43 | def connect(self): |
50 | 44 | """ |
51 | 45 | Connect to the server using the file name specified in the constructor. |
52 | 46 | """ |
53 | | - self.socket.connect(f"ipc://{self.file_name}") |
54 | | - |
55 | | - def start_server(self): |
56 | | - """ |
57 | | - Start the server using the file name specified in the constructor. |
58 | | - """ |
59 | | - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) |
60 | | - self.socket.setsockopt(zmq.SNDTIMEO, -1) |
61 | | - self.socket.bind(f"ipc://{self.file_name}") |
62 | | - self.poller = zmq.Poller() |
63 | | - self.poller.register(self.socket, zmq.POLLIN) |
64 | | - |
65 | | - def create_router(self): |
66 | | - """ |
67 | | - Create a ROUTER socket and bind it to the specified router path. |
68 | | - """ |
69 | | - self.router = self.context.socket(zmq.ROUTER) |
70 | | - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) |
71 | | - self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) |
72 | | - self.router.setsockopt(zmq.SNDTIMEO, -1) |
73 | | - self.router.bind(f"ipc://{self.router_path}") |
74 | | - zmq_client_logger.info(f"router path: {self.router_path}") |
| 47 | + pass |
75 | 48 |
|
76 | 49 | def send_json(self, data): |
77 | 50 | """ |
78 | 51 | Send a JSON-serializable object over the socket. |
79 | 52 | """ |
| 53 | + self._ensure_socket() |
80 | 54 | self.socket.send_json(data) |
81 | 55 |
|
82 | 56 | def recv_json(self): |
83 | 57 | """ |
84 | 58 | Receive a JSON-serializable object from the socket. |
85 | 59 | """ |
| 60 | + self._ensure_socket() |
86 | 61 | return self.socket.recv_json() |
87 | 62 |
|
88 | 63 | def send_pyobj(self, data): |
89 | 64 | """ |
90 | 65 | Send a Pickle-serializable object over the socket. |
91 | 66 | """ |
| 67 | + self._ensure_socket() |
92 | 68 | self.socket.send_pyobj(data) |
93 | 69 |
|
94 | 70 | def recv_pyobj(self): |
95 | 71 | """ |
96 | 72 | Receive a Pickle-serializable object from the socket. |
97 | 73 | """ |
| 74 | + self._ensure_socket() |
98 | 75 | return self.socket.recv_pyobj() |
99 | 76 |
|
100 | | - def pack_aggregated_data(self, data): |
101 | | - """ |
102 | | - Aggregate multiple responses into one and send them to the client. |
103 | | - """ |
104 | | - result = data[0] |
105 | | - if len(data) > 1: |
106 | | - for response in data[1:]: |
107 | | - result.add(response) |
108 | | - result = msgpack.packb([result.to_dict()]) |
109 | | - return result |
110 | | - |
111 | | - def send_multipart(self, req_id, data): |
112 | | - """ |
113 | | - Send a multipart message to the router socket. |
114 | | - """ |
115 | | - if self.router is None: |
116 | | - raise RuntimeError("Router socket not created. Call create_router() first.") |
117 | | - |
118 | | - while self.running: |
119 | | - with self.mutex: |
120 | | - if req_id not in self.req_dict: |
121 | | - try: |
122 | | - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) |
123 | | - req_id_str = request_id.decode("utf-8") |
124 | | - self.req_dict[req_id_str] = client |
125 | | - except zmq.Again: |
126 | | - time.sleep(0.001) |
127 | | - continue |
128 | | - else: |
129 | | - break |
130 | | - if self.req_dict[req_id] == -1: |
131 | | - if data[-1].finished: |
132 | | - with self.mutex: |
133 | | - self.req_dict.pop(req_id, None) |
134 | | - return |
135 | | - try: |
136 | | - start_send = time.time() |
137 | | - if self.aggregate_send: |
138 | | - result = self.pack_aggregated_data(data) |
139 | | - else: |
140 | | - result = msgpack.packb([response.to_dict() for response in data]) |
141 | | - self.router.send_multipart([self.req_dict[req_id], b"", result]) |
142 | | - zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") |
143 | | - except zmq.ZMQError as e: |
144 | | - zmq_client_logger.error(f"[{req_id}] zmq error: {e}") |
145 | | - self.req_dict[req_id] = -1 |
146 | | - except Exception as e: |
147 | | - zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") |
| 77 | + @abstractmethod |
| 78 | + def close(self): |
| 79 | + pass |
148 | 80 |
|
149 | | - if data[-1].finished: |
150 | | - with self.mutex: |
151 | | - self.req_dict.pop(req_id, None) |
152 | | - zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}") |
153 | 81 |
|
154 | | - def receive_json_once(self, block=False): |
155 | | - """ |
156 | | - Receive a single message from the socket. |
157 | | - """ |
158 | | - if self.socket is None or self.socket.closed: |
159 | | - return "zmp socket has closed", None |
160 | | - try: |
161 | | - flags = zmq.NOBLOCK if not block else 0 |
162 | | - return None, self.socket.recv_json(flags=flags) |
163 | | - except zmq.Again: |
164 | | - return None, None |
165 | | - except Exception as e: |
166 | | - self.close() |
167 | | - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") |
168 | | - return str(e), None |
| 82 | +class ZmqIpcClient(ZmqClientBase): |
| 83 | + def __init__(self, name, mode): |
| 84 | + self.name = name |
| 85 | + self.mode = mode |
| 86 | + self.file_name = f"/dev/shm/{name}.socket" |
| 87 | + self.context = zmq.Context() |
| 88 | + self.socket = self.context.socket(self.mode) |
169 | 89 |
|
170 | | - def receive_pyobj_once(self, block=False): |
171 | | - """ |
172 | | - Receive a single message from the socket. |
173 | | - """ |
174 | | - if self.socket is None or self.socket.closed: |
175 | | - return "zmp socket has closed", None |
176 | | - try: |
177 | | - flags = zmq.NOBLOCK if not block else 0 |
178 | | - return None, self.socket.recv_pyobj(flags=flags) |
179 | | - except zmq.Again: |
180 | | - return None, None |
181 | | - except Exception as e: |
182 | | - self.close() |
183 | | - zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}") |
184 | | - return str(e), None |
| 90 | + def _create_socket(self): |
| 91 | + """create and return a ZeroMQ socket.""" |
| 92 | + self.context = zmq.Context() |
| 93 | + return self.context.socket(self.mode) |
185 | 94 |
|
186 | | - def _clear_ipc(self, name): |
187 | | - """ |
188 | | - Remove the IPC file with the given name. |
189 | | - """ |
190 | | - if os.path.exists(name): |
191 | | - try: |
192 | | - os.remove(name) |
193 | | - except OSError as e: |
194 | | - zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}") |
| 95 | + def connect(self): |
| 96 | + self._ensure_socket() |
| 97 | + self.socket.connect(f"ipc://{self.file_name}") |
195 | 98 |
|
196 | 99 | def close(self): |
197 | 100 | """ |
198 | | - Close the socket and context, and remove the IPC files. |
| 101 | + Close the socket and context. |
199 | 102 | """ |
200 | | - if not self.running: |
201 | | - return |
202 | | - |
203 | | - self.running = False |
204 | | - zmq_client_logger.info("Closing ZMQ connection...") |
| 103 | + llm_logger.info("ZMQ client is closing connection...") |
205 | 104 | try: |
206 | | - if hasattr(self, "socket") and not self.socket.closed: |
| 105 | + if self.socket is not None and not self.socket.closed: |
| 106 | + self.socket.setsockopt(zmq.LINGER, 0) |
207 | 107 | self.socket.close() |
208 | | - |
209 | | - if self.router is not None and not self.router.closed: |
210 | | - self.router.close() |
211 | | - |
212 | | - if not self.context.closed: |
| 108 | + if self.context is not None: |
213 | 109 | self.context.term() |
214 | 110 |
|
215 | | - self._clear_ipc(self.file_name) |
216 | | - self._clear_ipc(self.router_path) |
217 | 111 | except Exception as e: |
218 | | - zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") |
| 112 | + llm_logger.warning(f"ZMQ client failed to close connection - {e}") |
219 | 113 | return |
220 | | - |
221 | | - def __exit__(self, exc_type, exc_val, exc_tb): |
222 | | - self.close() |
0 commit comments