-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathserver_socket_simulator.py
363 lines (336 loc) · 13.2 KB
/
server_socket_simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
import multiprocessing
import random
import socket
import threading
from Queue import Empty
from collections import Counter
import time
from ip import IPDatagram
from logger import LOG
from tcp import TCPSegment
class ServerSocketSimulator:
    '''
    Simulated TCP server socket running on top of a simulated host.

    No heartbeat (keep-alive) is implemented.

    A background thread running ``receive_tcp`` is started on
    construction. Data (PSH) segments are delivered to the host's chat
    window by ``recv``; control segments without PSH are parked in
    ``ack_queue`` / ``fin_queue`` / ``syn_queue`` for ``accept`` and
    ``listen_tear_down`` to consume.
    '''

    def __init__(self, host=None, timeout=180, tick=2):
        '''
        :param host: simulated host; despite the ``None`` default it is
            required (``host.intList[0].ip_addr`` is read immediately).
        :param timeout: total receive timeout budget, in seconds.
        :param tick: per-retry interval; ``timeout // tick`` is the number
            of retries ``_recv`` may spend.
        '''
        # IP
        self.host = host
        self.ip_src = socket.inet_aton(self.host.intList[0].ip_addr)
        self.ip_dest = ''
        # port
        self.port_src = 80
        self.port_dest = 0
        # TCP
        self.tcp_seq = random.randint(0x0001, 0xffff)
        self.tcp_ack_seq = 0
        # last datagram sent, kept so _send(retry=True) can retransmit it
        self.prev_data = ''
        # congestion window: max payload bytes per outgoing segment
        self.tcp_cwind = 1024
        # size of the receive buffer
        self.tcp_adwind = 65535
        # in-order payloads waiting to be returned by recv()
        self.recv_buf = []
        # out-of-order segments keyed by their sequence number
        self.tmp_buf = {}
        self.tick = tick
        # floor division keeps this an int on both Python 2 and 3
        self.max_retry = timeout // tick
        self.metrics = Counter(send=0, recv=0, erecv=0,
                               retry=0, cksumfail=0)
        self.syn_queue = multiprocessing.Queue()
        self.ack_queue = multiprocessing.Queue()
        self.fin_queue = multiprocessing.Queue()
        self.receiving_tcp = threading.Thread(target=self.receive_tcp)
        self.receiving_tcp.start()

    def receive_tcp(self):
        '''
        Background receive loop: print every non-empty payload returned
        by ``recv``. Runs until ``recv`` raises (timeout or reset).
        '''
        while True:
            tcp_data = self.recv()
            if tcp_data:
                print(self.host.name + ":received message:" + tcp_data)

    def accept(self):
        '''
        Block until a client SYN addressed to ``port_src`` arrives, answer
        it with SYN/ACK, then start a thread that waits for the client's
        FIN (``listen_tear_down``).
        '''
        tcp_segment = self.listen()
        self.port_dest = tcp_segment.tcp_src_port
        # 3-way handshake, server side
        self._tcp_handshake(tcp_segment)
        print("........server........%s has connected with %s........." %
              (self.host.name, socket.inet_ntoa(self.ip_dest)))
        listening_down = threading.Thread(target=self.listen_tear_down)
        listening_down.start()

    def listen_tear_down(self):
        '''
        Wait for the client's FIN, answer the teardown, reset the socket
        state and start the chat window's ``keep_accept`` in a new thread.

        Handles exactly one teardown: ``accept`` starts a fresh listener
        for every connection, so looping here would leave a stale listener
        thread behind per connection (each able to spawn yet another
        ``keep_accept`` thread on a later FIN).
        '''
        tcp_segment = self.recv_non_psh(self.fin_queue)
        self.port_dest = tcp_segment.tcp_src_port
        self._tcp_response_tear_down(tcp_segment)
        self.close()
        print("........server........%s has disconnected with %s........." %
              (self.host.name, socket.inet_ntoa(self.ip_dest)))
        time.sleep(0.001)
        new_accept = threading.Thread(target=self.host.chat_window.keep_accept)
        new_accept.start()

    def bind(self):
        '''Re-read the local IP address from the host's first interface.'''
        self.ip_src = socket.inet_aton(self.host.intList[0].ip_addr)

    def listen(self):
        '''Block until a SYN segment for ``port_src`` is available; return it.'''
        LOG.info(self.host.name + " : TCP server socket listening...")
        time.sleep(0.01)
        return self.recv_non_psh(self.syn_queue)

    def send(self, data=''):
        '''
        Send all of ``data`` as PSH/ACK segments carrying at most
        ``tcp_cwind`` bytes each, advancing ``tcp_seq`` by the bytes sent.

        No acknowledgement is awaited; segments are fire-and-forget.

        :return: ``len(data)``
        '''
        send_length = 0
        total_len = len(data)
        while send_length < total_len:
            # Additive increase. Integer division keeps the window an int
            # (slicing needs one); it only grows while the window is 1.
            self.tcp_cwind += 1 // self.tcp_cwind
            chunk = data[send_length:send_length + self.tcp_cwind]
            self._send(chunk, ack=1, psh=1)
            # the final chunk may be shorter than the window
            self.tcp_seq += len(chunk)
            send_length += self.tcp_cwind
        return total_len

    def recv_non_psh(self, queue):
        '''
        Poll ``queue`` until it yields a segment addressed to ``port_src``
        and return it. Segments for other ports are discarded. Blocks
        indefinitely; never returns ``None``.
        '''
        time.sleep(0.01)
        while True:
            try:
                tcp_segment = queue.get(False)
                if tcp_segment.tcp_dest_port == self.port_src:
                    return tcp_segment
            except Empty:
                pass
            finally:
                # throttle polling (also runs before a successful return)
                time.sleep(0.1)

    def recv(self, bufsize=8192):
        '''
        Receive PSH data segments and return their concatenated payload.

        Runs ``1 + bufsize // tcp_adwind`` rounds; each round keeps reading
        until at least ``tcp_adwind`` payload bytes are buffered (a FIN
        ends the whole call early), then flushes the buffer into the
        result. In-order payloads are acknowledged and pushed to the chat
        window as they arrive; segments ahead of ``tcp_ack_seq`` are parked
        in ``tmp_buf`` until the gap is filled; stale duplicates and PSH
        segments without ACK are ignored.

        A segment without PSH is routed to ``ack_queue`` / ``fin_queue`` /
        ``syn_queue`` according to its flags and ``None`` is returned.

        :return: the received payload, or ``None`` for a control segment.
        :raises RuntimeError: if ``_recv`` gives up (connection timeout) or
            the peer resets the connection.
        '''
        tcp_data = ''
        # floor division keeps ``times`` an int on Python 3 as well; a
        # fractional countdown would never reach exactly zero
        times = 1 + bufsize // self.tcp_adwind
        fin = False
        while times:
            # every round fills a fresh receive window
            receive_length = 0
            while receive_length < self.tcp_adwind:
                tcp_segment = self._recv(self.max_retry)
                if tcp_segment is None:
                    raise RuntimeError('Connection timeout')
                if not tcp_segment.tcp_fpsh:
                    # control segment: hand it to whoever is waiting for it
                    if tcp_segment.tcp_fack:
                        self.ack_queue.put(tcp_segment)
                    if tcp_segment.tcp_ffin:
                        self.fin_queue.put(tcp_segment)
                    if tcp_segment.tcp_fsyn:
                        self.syn_queue.put(tcp_segment)
                    # NOTE(review): recv_buf is not flushed on this path, so
                    # already-delivered payloads accumulate there across calls.
                    return None
                if not tcp_segment.tcp_fack:
                    continue
                if tcp_segment.tcp_seq == self.tcp_ack_seq:
                    LOG.debug('Recv in-order TCP segment')
                    receive_length += self._enbuf(tcp_segment)
                    self._send(ack=1)
                    self._notify_chat(tcp_segment)
                    if tcp_segment.tcp_ffin:
                        fin = True
                        break
                    # the gap may now be filled: drain parked segments
                    while self.tcp_ack_seq in self.tmp_buf:
                        # pop rather than read, so a zero-length segment
                        # cannot keep this loop spinning forever
                        tcp_segment = self.tmp_buf.pop(self.tcp_ack_seq)
                        receive_length += self._enbuf(tcp_segment)
                        self._notify_chat(tcp_segment)
                        if tcp_segment.tcp_ffin:
                            fin = True
                            break
                    self._send(ack=1)
                    if fin:
                        break
                elif (tcp_segment.tcp_seq > self.tcp_ack_seq and
                      tcp_segment.tcp_seq not in self.tmp_buf):
                    LOG.debug('Recv out-of-order TCP segment')
                    self.tmp_buf[tcp_segment.tcp_seq] = tcp_segment
                # anything else is a stale or duplicate segment: ignore it
            tcp_data = ''.join([tcp_data, self._debuf()])
            if fin:
                return tcp_data
            times -= 1
        return tcp_data

    def _notify_chat(self, tcp_segment):
        '''Push a non-empty segment payload to the host's chat window queue.'''
        if tcp_segment.data:
            self.host.chat_window.message_queue.put(tcp_segment.data)

    def close(self):
        '''
        Tear down the simulated connection: discard pending control
        segments and detach this socket from the chat window.
        '''
        for pending in (self.ack_queue, self.fin_queue, self.syn_queue):
            self._drain(pending)
        self.host.chat_window.current_socket = None

    @staticmethod
    def _drain(q):
        '''
        Discard everything currently available in ``q`` without blocking
        (an ``empty()`` check followed by a blocking ``get()`` could hang
        if another consumer takes the item in between).
        '''
        while True:
            try:
                q.get_nowait()
            except Empty:
                return

    def _tcp_handshake(self, tcp_segment):
        '''
        Server side of the TCP 3-way handshake: validate the client's SYN
        and answer with SYN/ACK. The client's final ACK is not awaited.

        :raises RuntimeError: if no segment was received or it is not a SYN.
        '''
        # check timeout
        if tcp_segment is None:
            raise RuntimeError('TCP Server handshake failed, connection timeout')
        # check the client's SYN
        if not tcp_segment.tcp_fsyn:
            raise RuntimeError('TCP Server handshake failed, bad server response')
        # NOTE(review): tcp_ack_seq is not advanced to the client's sequence
        # number + 1 here; in-order delivery in recv() depends on the
        # client's first data seq matching the current tcp_ack_seq.
        # Confirm against the client implementation.
        self._send(syn=1, ack=1)

    def _send(self, data='', retry=False, urg=0, ack=0, psh=0,
              rst=0, syn=0, fin=0):
        '''
        Send ``data`` in one TCP segment with the given flags, wrapped in
        an IP datagram, through ``host.send_datagram``.

        With ``retry=True`` the previous datagram is re-sent unchanged (all
        other arguments are ignored) and the congestion window is halved,
        never below 1 byte.

        :return: whatever ``host.send_datagram`` returns.
        '''
        if retry:
            # Multiplicative decrease, floored at 1 so send() can neither
            # divide by zero nor loop forever on an empty window.
            self.tcp_cwind = max(1, self.tcp_cwind // 2)
            return self.host.send_datagram(self.prev_data)
        # build TCP segment
        # NOTE(review): the congestion window is advertised as the TCP
        # window here, not tcp_adwind -- confirm this is intended.
        tcp_segment = TCPSegment(ip_src_addr=self.ip_src,
                                 ip_dest_addr=self.ip_dest,
                                 tcp_src_port=self.port_src,
                                 tcp_dest_port=self.port_dest,
                                 tcp_seq=self.tcp_seq,
                                 tcp_ack_seq=self.tcp_ack_seq,
                                 tcp_furg=urg, tcp_fack=ack, tcp_fpsh=psh,
                                 tcp_frst=rst, tcp_fsyn=syn, tcp_ffin=fin,
                                 tcp_adwind=self.tcp_cwind, data=data)
        ip_data = tcp_segment.pack()
        print(self.host.name + ": server send TCP: " + repr(tcp_segment))
        # build IP datagram
        ip_datagram = IPDatagram(ip_src_addr=self.ip_src,
                                 ip_dest_addr=self.ip_dest,
                                 data=ip_data)
        eth_data = ip_datagram.pack()
        self.metrics['send'] += 1
        # remember it for a possible retransmission
        self.prev_data = eth_data
        return self.host.send_datagram(eth_data)

    def _recv(self, max_retry, bufsize=1500):
        '''
        Return the next valid TCP segment addressed to this socket, or
        ``None`` once ``max_retry`` retries have been used up.

        Datagrams failing the IP/TCP filters are skipped. An IP checksum
        failure or an ``Empty`` queue spends one retry via ``_retry``
        (which retransmits the last datagram first).

        :raises RuntimeError: if the peer resets the connection.
        '''
        while max_retry > 0:
            self.metrics['recv'] += 1
            try:
                # NOTE(review): if tcp_ip_queue is a Queue or
                # multiprocessing.Queue, a blocking get() never raises
                # Empty, so this retry path needs a get() timeout to fire.
                ip_bytes = self.host.tcp_ip_queue.get()
            except Empty:
                # timeout: re-send and re-receive
                return self._retry(bufsize, max_retry)
            ip_datagram = IPDatagram("", "", data="")
            ip_datagram.unpack(ip_bytes)
            # IP filtering
            if not self._ip_expected(ip_datagram):
                continue
            # IP checksum
            if not ip_datagram.verify_checksum():
                return self._retry(bufsize, max_retry)
            self.ip_dest = ip_datagram.ip_src_addr
            # process TCP segment
            tcp_segment = TCPSegment(self.ip_src, self.ip_dest)
            tcp_segment.unpack(ip_datagram.data)
            tcp_segment.pack()
            # TCP filtering
            if not self._tcp_expected(tcp_segment):
                continue
            # TCP checksum verification is currently disabled.
            LOG.debug('Recv: %s' % tcp_segment)
            self.metrics['erecv'] += 1
            print(self.host.name + "Server Receive TCP: " + repr(tcp_segment))
            return tcp_segment
        return None

    def _ip_expected(self, ip_datagram):
        '''
        Return True if the received ip_datagram is meant for this socket:
        1. ip_ver is 4
        2. ip_dest_addr is this socket's own address (``ip_src``)
        3. ip_proto is TCP (6)
        '''
        return (ip_datagram.ip_ver == 4 and
                ip_datagram.ip_dest_addr == self.ip_src and
                ip_datagram.ip_proto == socket.IPPROTO_TCP)

    def _retry(self, bufsize, max_retry):
        '''
        Spend one retry: count it, retransmit the last datagram and go back
        to ``_recv`` with the reduced budget. Mutually recursive with
        ``_recv``; the recursion depth is bounded by ``max_retry``.
        '''
        self.metrics['retry'] += 1
        self._send(retry=True, ack=1)
        # _recv takes (max_retry, bufsize) -- in that order
        return self._recv(max_retry - 1, bufsize)

    def _enbuf(self, tcp_segment):
        '''
        Append an in-order payload to the receive buffer, adopt the peer's
        ack number as our sequence number and advance ``tcp_ack_seq``.

        :return: the payload length in bytes.
        '''
        self.recv_buf.append(tcp_segment.data)
        elen = len(tcp_segment.data)
        self.tcp_seq = tcp_segment.tcp_ack_seq
        self.tcp_ack_seq += elen
        return elen

    def _debuf(self):
        '''
        Return all buffered payload as one string and empty both the
        in-order buffer and the out-of-order store.
        '''
        tcp_data = ''.join(self.recv_buf)
        del self.recv_buf[:]
        self.tmp_buf.clear()
        return tcp_data

    def _tcp_expected(self, tcp_segment):
        '''
        Return True if the received tcp_segment is meant for this socket,
        i.e. its tcp_dest_port equals ``port_src``.

        :raises RuntimeError: if the segment carries RST.
        '''
        if tcp_segment.tcp_dest_port != self.port_src:
            return False
        elif tcp_segment.tcp_frst:
            raise RuntimeError('Connection reset by server')
        else:
            return True

    def dump_metrics(self):
        '''
        Dump the metrics counters for debug usage.

        :return: ``(text, counter)`` -- one tab-indented ``name: value``
            line per counter, and the Counter itself.
        '''
        dump = '\n'.join('\t%s: %d' % (k, v) for (k, v)
                         in self.metrics.items())
        return dump, self.metrics

    def _tcp_response_tear_down(self, tcp_segment):
        '''
        Answer the client's FIN: acknowledge it, then send our own FIN.
        The client's final ACK is not awaited.

        :raises RuntimeError: if ``tcp_segment`` is None or lacks FIN.
        '''
        if tcp_segment is None:
            raise RuntimeError('TCP teardown failed, connection timeout')
        print(self.host.name + " from _tcp_response_tear_down:" + repr(tcp_segment))
        print(self.host.name + " from _tcp_response_tear_down fin: %d" % tcp_segment.tcp_ffin)
        # check the client's FIN
        if not tcp_segment.tcp_ffin:
            raise RuntimeError('TCP teardown failed, server not FIN')
        self.tcp_seq = tcp_segment.tcp_ack_seq
        # a FIN consumes one sequence number
        self.tcp_ack_seq = tcp_segment.tcp_seq + 1
        self._send(ack=1)
        self._send(fin=1)