-
Notifications
You must be signed in to change notification settings - Fork 6
/
client.py
executable file
·148 lines (123 loc) · 4.75 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
# Created @ 2016-03-18 16:23 by @radaiming
#
import argparse
import asyncio
import logging
import struct
import sys
import websockets
# asyncio.ensure_future() only exists from Python 3.4.4 on; earlier 3.4
# releases spell it asyncio.async().  'async' has been a reserved keyword
# since Python 3.7, so that fallback must be looked up with getattr():
# writing asyncio.async literally is a SyntaxError that prevents the whole
# module from being parsed on any modern interpreter.
# NOTE(review): the @asyncio.coroutine decorator used by the coroutines in
# this file was removed in Python 3.11 -- confirm the supported versions.
try:
    asyncio_ensure_future = asyncio.ensure_future
except AttributeError:
    asyncio_ensure_future = getattr(asyncio, 'async')
# Queue of pending (client_addr, query_bytes) tuples; created in main(),
# filled by ListenProtocol.datagram_received(), drained by connect_ws_server().
query_queue = None
# UDP transport of the local listener; set by ListenProtocol.connection_made()
# and used to send answers back to the querying clients.
listen_transport = None
def byte_2_domain(data):
    """Decode the DNS wire-format name at the start of *data* to dotted text.

    A name is a sequence of length-prefixed labels ended by a zero byte;
    whatever follows the terminator (e.g. QTYPE/QCLASS) is ignored::

        >>> byte_2_domain(b'\\x01t\\x02cn\\x00')
        't.cn'

    The result is only used for log messages, so bad input never raises:

    * a name that runs past the end of *data* yields ``'unknown domain'``;
    * label bytes that are not valid UTF-8 are replaced with U+FFFD;
    * the root name (a lone zero byte) yields ``''``.

    Compression pointers (RFC 1035 section 4.1.4) are not followed.
    """
    labels = []
    pos = 0
    try:
        while True:
            # struct.error here means the length byte itself is missing
            (length,) = struct.unpack_from('!B', data, pos)
            pos += 1
            if length == 0:
                break
            label = data[pos:pos + length]
            if len(label) < length:
                # label claims more bytes than the packet contains
                return 'unknown domain'
            labels.append(label)
            pos += length
    except struct.error:
        return 'unknown domain'
    # labels are arbitrary octets on the wire; never let decoding raise
    return b'.'.join(labels).decode('utf-8', errors='replace')
@asyncio.coroutine
def send_to_server(ws, client_addr, data):
    """Forward one DNS query from a local client to the server over *ws*.

    The websocket message is ``b'host:port' + b'\\x00\\x00' + data``.  The
    ``host:port`` key identifies the client when the answer comes back, and
    ``\\x00\\x00`` separates it from the query.  The host part can itself
    contain ':' (IPv6), so the key must be parsed at its last ':'.

    :param ws: open websocket connection to the tunnel server
    :param client_addr: address tuple reported by the UDP socket
    :param data: raw DNS query bytes
    """
    # IPv6 sockets report (host, port, flowinfo, scope_id); formatting that
    # 4-tuple with '%s:%d' raises TypeError, and only host/port are needed.
    host, port = client_addr[:2]
    b_client_addr = ('%s:%d' % (host, port)).encode('utf-8')
    packed_data = b_client_addr + b'\x00\x00' + data
    # Decoding the name is only worth doing when it will actually be logged;
    # data[12:] skips the fixed 12-byte DNS header to reach the question.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug('querying ' + byte_2_domain(data[12:]))
    yield from ws.send(packed_data)
@asyncio.coroutine
def receive_data(ws):
    """Relay every message received on *ws* back to the local clients.

    Long-running: each message is handed to send_back_to_client() in its own
    task, so reading the next message never waits on delivery of the last.
    Returns quietly once the websocket connection is closed.
    """
    while True:
        try:
            packed = yield from ws.recv()
        except websockets.exceptions.ConnectionClosed:
            # end of this connection; connect_ws_server() reconnects and
            # starts a new receiver when a later query needs one
            return
        asyncio_ensure_future(send_back_to_client(packed))
@asyncio.coroutine
def send_back_to_client(packed_data):
    """Deliver one answer from the server to the local client that asked for it.

    *packed_data* is ``b'host:port' + b'\\x00\\x00' + answer``.  The key is
    split at its LAST ':' because an IPv6 host contains colons itself.  A
    message that does not have this shape is logged and dropped.  Raising
    here would only end up as an unretrieved exception in a background task.
    """
    try:
        b_client_addr, data = packed_data.split(b'\x00\x00', 1)
        b_client_ip, b_client_port = b_client_addr.rsplit(b':', 1)
        client_addr = (b_client_ip.decode('utf-8'), int(b_client_port))
    except (TypeError, ValueError) as exp:
        # TypeError: a text (str) frame; ValueError: missing separator,
        # missing ':', non-numeric port or undecodable host
        logging.warning('dropping malformed packet from server: %s', exp)
        return
    listen_transport.sendto(data, client_addr)
    # data[12:] skips the fixed 12-byte DNS header; only decode when logged
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug('result sending: ' + byte_2_domain(data[12:]))
@asyncio.coroutine
def connect_ws_server(server_addr):
    """Forward queued queries to the server, (re)connecting on demand.

    Consumer side of ``query_queue``: waits for each (client_addr, data)
    tuple and makes sure a websocket to *server_addr* is open.  It connects
    on the first query and reconnects whenever the previous connection
    reports it is no longer open.  Each new connection also gets its own
    receive_data() task.  The query itself is sent by send_to_server() in a
    background task, so this loop never blocks on a send.

    Errors are logged and the loop keeps running; if connecting fails, the
    query that triggered the attempt is not forwarded.
    """
    connection = None
    while True:
        try:
            client_addr, data = yield from query_queue.get()
            if not connection or not connection.open:
                # first connect vs. reconnect after the link dropped
                logging.info('lost connection to server, reconnecting...'
                             if connection else 'connecting to server...')
                connection = yield from websockets.connect(server_addr)
                logging.info('server connected')
                asyncio_ensure_future(receive_data(connection))
            asyncio_ensure_future(send_to_server(connection, client_addr, data))
        except Exception as exp:
            # connection problems surface as many different exception types;
            # catch them all so the consumer loop survives
            logging.error('connection error: ' + str(exp))
class ListenProtocol(asyncio.DatagramProtocol):
    """UDP protocol that accepts local DNS queries and queues them for the tunnel.

    Publishes its transport through the module-level ``listen_transport`` and
    pushes every received datagram onto the module-level ``query_queue``.
    """

    def __init__(self):
        super().__init__()
        logging.info('listening for incoming query')

    def connection_made(self, transport):
        """Remember the UDP transport so answers can be sent back to clients."""
        global listen_transport
        listen_transport = transport

    def datagram_received(self, data, client_addr):
        """Queue one incoming query as a ``(client_addr, data)`` tuple (producer)."""
        # main() creates the queue without a maxsize, so put_nowait() cannot
        # raise QueueFull; enqueueing directly avoids spawning a Task per
        # datagram just to await an unbounded put().
        query_queue.put_nowait((client_addr, data))
def main():
    """Parse the command line, bind the local UDP listener and run the client.

    Blocks in the event loop until interrupted with Ctrl-C, then closes the
    listener and the loop.  Exits with status 1 if the listen address cannot
    be bound.
    """
    global query_queue
    parser = argparse.ArgumentParser(description='A simple DNS tunnel over websocket')
    parser.add_argument('-c', action='store', dest='server_addr', required=True,
                        help='set server url, like ws://test.com/dns')
    parser.add_argument('-b', action='store', dest='bind_address', default='127.0.0.1',
                        help='bind to this address, default to 127.0.0.1')
    parser.add_argument('-p', action='store', dest='bind_port', type=int, default=5353,
                        help='bind to this port, default to 5353')
    parser.add_argument('--debug', action='store_true', dest='debug', default=False,
                        help='enable debug outputing')
    args = parser.parse_args(sys.argv[1:])
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO,
                        format='%(asctime)s %(levelname)-8s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    # Create and install the loop explicitly; letting get_event_loop() create
    # one implicitly is deprecated on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # No loop= argument: it was removed from asyncio.Queue in Python 3.10 and
    # older versions default to the current loop installed just above.
    query_queue = asyncio.Queue()
    listen = loop.create_datagram_endpoint(
        ListenProtocol,
        local_addr=(args.bind_address, args.bind_port))
    # Bind before entering run_forever() so a failure such as "address already
    # in use" stops the program instead of vanishing inside a background task.
    try:
        transport, _ = loop.run_until_complete(listen)
    except OSError as exp:
        logging.error('cannot listen on %s:%d: %s',
                      args.bind_address, args.bind_port, exp)
        loop.close()
        sys.exit(1)
    asyncio_ensure_future(connect_ws_server(args.server_addr), loop=loop)
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    transport.close()
    loop.close()
if __name__ == '__main__':
    # Script entry point: start the client only when run directly, not when
    # this module is imported.
    main()