|
| 1 | +#!/usr/bin/python |
| 2 | + |
| 3 | +# Copyright (c) 2009 Tom Pinckney |
| 4 | +# |
| 5 | +# Permission is hereby granted, free of charge, to any person |
| 6 | +# obtaining a copy of this software and associated documentation |
| 7 | +# files (the "Software"), to deal in the Software without |
| 8 | +# restriction, including without limitation the rights to use, |
| 9 | +# copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | +# copies of the Software, and to permit persons to whom the |
| 11 | +# Software is furnished to do so, subject to the following |
| 12 | +# conditions: |
| 13 | +# |
| 14 | +# The above copyright notice and this permission notice shall be |
| 15 | +# included in all copies or substantial portions of the Software. |
| 16 | +# |
| 17 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
| 18 | +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES |
| 19 | +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
| 20 | +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT |
| 21 | +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, |
| 22 | +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING |
| 23 | +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR |
| 24 | +# OTHER DEALINGS IN THE SOFTWARE. |
| 25 | + |
| 26 | +import sys |
| 27 | +import socket |
| 28 | +import struct |
| 29 | +import ConfigParser |
| 30 | +import signal |
| 31 | +import getopt |
| 32 | + |
| 33 | +from utils import * |
| 34 | + |
| 35 | +class DnsError(Exception): |
| 36 | + pass |
| 37 | + |
| 38 | +def serve(): |
| 39 | + udps = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
| 40 | + udps.bind((listen_host, listen_port)) |
| 41 | + #ns_resource_records, ar_resource_records = compute_name_server_resources(_name_servers) |
| 42 | + ns_resource_records = ar_resource_records = [] |
| 43 | + while True: |
| 44 | + try: |
| 45 | + req_pkt, src_addr = udps.recvfrom(512) # max UDP DNS pkt size |
| 46 | + except socket.error: |
| 47 | + continue |
| 48 | + qid = None |
| 49 | + try: |
| 50 | + exception_rcode = None |
| 51 | + try: |
| 52 | + qid, question, qtype, qclass = parse_request(req_pkt) |
| 53 | + except: |
| 54 | + exception_rcode = 1 |
| 55 | + raise Exception("could not parse query") |
| 56 | + question = map(lambda x: x.lower(), question) |
| 57 | + found = False |
| 58 | + for config in config_files.values(): |
| 59 | + if question[1:] == config['domain']: |
| 60 | + query = question[0] |
| 61 | + elif question == config['domain']: |
| 62 | + query = '' |
| 63 | + else: |
| 64 | + continue |
| 65 | + rcode, an_resource_records = config['source'].get_response(query, config['domain'], qtype, qclass, src_addr) |
| 66 | + if rcode == 0 and 'filters' in config: |
| 67 | + for f in config['filters']: |
| 68 | + an_resource_records = f.filter(query, config['domain'], qtype, qclass, src_addr, an_resource_records) |
| 69 | + resp_pkt = format_response(qid, question, qtype, qclass, rcode, an_resource_records, ns_resource_records, ar_resource_records) |
| 70 | + found = True |
| 71 | + break |
| 72 | + if not found: |
| 73 | + exception_rcode = 3 |
| 74 | + raise Exception("query is not for our domain: %s" % ".".join(question)) |
| 75 | + except: |
| 76 | + if qid: |
| 77 | + if exception_rcode is None: |
| 78 | + exception_rcode = 2 |
| 79 | + resp_pkt = format_response(qid, question, qtype, qclass, exception_rcode, [], [], []) |
| 80 | + else: |
| 81 | + continue |
| 82 | + udps.sendto(resp_pkt, src_addr) |
| 83 | + |
| 84 | +def compute_name_server_resources(name_servers): |
| 85 | + ns = [] |
| 86 | + ar = [] |
| 87 | + for name_server, ip, ttl in name_servers: |
| 88 | + ns.append({'qtype':2, 'qclass':1, 'ttl':ttl, 'rdata':labels2str(name_server)}) |
| 89 | + ar.append({'qtype':1, 'qclass':1, 'ttl':ttl, 'rdata':struct.pack("!I", ip)}) |
| 90 | + return ns, ar |
| 91 | + |
| 92 | +def parse_request(packet): |
| 93 | + hdr_len = 12 |
| 94 | + header = packet[:hdr_len] |
| 95 | + qid, flags, qdcount, _, _, _ = struct.unpack('!HHHHHH', header) |
| 96 | + qr = (flags >> 15) & 0x1 |
| 97 | + opcode = (flags >> 11) & 0xf |
| 98 | + rd = (flags >> 8) & 0x1 |
| 99 | + #print "qid", qid, "qdcount", qdcount, "qr", qr, "opcode", opcode, "rd", rd |
| 100 | + if qr != 0 or opcode != 0 or qdcount == 0: |
| 101 | + raise DnsError("Invalid query") |
| 102 | + body = packet[hdr_len:] |
| 103 | + labels = [] |
| 104 | + offset = 0 |
| 105 | + while True: |
| 106 | + label_len, = struct.unpack('!B', body[offset:offset+1]) |
| 107 | + offset += 1 |
| 108 | + if label_len & 0xc0: |
| 109 | + raise DnsError("Invalid label length %d" % label_len) |
| 110 | + if label_len == 0: |
| 111 | + break |
| 112 | + label = body[offset:offset+label_len] |
| 113 | + offset += label_len |
| 114 | + labels.append(label) |
| 115 | + qtype, qclass= struct.unpack("!HH", body[offset:offset+4]) |
| 116 | + if qclass != 1: |
| 117 | + raise DnsError("Invalid class: " + qclass) |
| 118 | + return (qid, labels, qtype, qclass) |
| 119 | + |
| 120 | +def format_response(qid, question, qtype, qclass, rcode, an_resource_records, ns_resource_records, ar_resource_records): |
| 121 | + resources = [] |
| 122 | + resources.extend(an_resource_records) |
| 123 | + num_an_resources = len(an_resource_records) |
| 124 | + num_ns_resources = num_ar_resources = 0 |
| 125 | + if rcode == 0: |
| 126 | + resources.extend(ns_resource_records) |
| 127 | + resources.extend(ar_resource_records) |
| 128 | + num_ns_resources = len(ns_resource_records) |
| 129 | + num_ar_resources = len(ar_resource_records) |
| 130 | + pkt = format_header(qid, rcode, num_an_resources, num_ns_resources, num_ar_resources) |
| 131 | + pkt += format_question(question, qtype, qclass) |
| 132 | + for resource in resources: |
| 133 | + pkt += format_resource(resource, question) |
| 134 | + return pkt |
| 135 | + |
| 136 | +def format_header(qid, rcode, ancount, nscount, arcount): |
| 137 | + flags = 0 |
| 138 | + flags |= (1 << 15) |
| 139 | + flags |= (1 << 10) |
| 140 | + flags |= (rcode & 0xf) |
| 141 | + hdr = struct.pack("!HHHHHH", qid, flags, 1, ancount, nscount, arcount) |
| 142 | + return hdr |
| 143 | + |
| 144 | +def format_question(question, qtype, qclass): |
| 145 | + q = labels2str(question) |
| 146 | + q += struct.pack("!HH", qtype, qclass) |
| 147 | + return q |
| 148 | + |
| 149 | +def format_resource(resource, question): |
| 150 | + r = '' |
| 151 | + r += labels2str(question) |
| 152 | + r += struct.pack("!HHIH", resource['qtype'], resource['qclass'], resource['ttl'], len(resource['rdata'])) |
| 153 | + r += resource['rdata'] |
| 154 | + return r |
| 155 | + |
| 156 | +def read_config(): |
| 157 | + for config_file in config_files: |
| 158 | + config_files[config_file] = config = {} |
| 159 | + config_parser = ConfigParser.SafeConfigParser() |
| 160 | + try: |
| 161 | + config_parser.read(config_file) |
| 162 | + config_values = config_parser.items("default") |
| 163 | + except: |
| 164 | + die("Error reading config file %s\n" % config_file) |
| 165 | + |
| 166 | + for var, value in config_values: |
| 167 | + if var == "domain": |
| 168 | + config['domain'] = value.split(".") |
| 169 | + elif var == "name servers": |
| 170 | + config['name_servers'] = [] |
| 171 | + split_name_servers = value.split(":") |
| 172 | + num_split_name_servers = len(split_name_servers) |
| 173 | + for i in range(0,num_split_name_servers,3): |
| 174 | + server = split_name_servers[i] |
| 175 | + ip = split_name_servers[i+1] |
| 176 | + ttl = int(split_name_servers[i+2]) |
| 177 | + config['name_servers'].append((server.split("."), ipstr2int(ip), ttl)) |
| 178 | + elif var == 'source': |
| 179 | + module_and_args = value.split(":") |
| 180 | + module = module_and_args[0] |
| 181 | + args = module_and_args[1:] |
| 182 | + source_module = __import__(module, {}, {}, ['']) |
| 183 | + source_instance = source_module.Source(*args) |
| 184 | + config['source'] = source_instance |
| 185 | + elif var == 'filters': |
| 186 | + config['filters'] = [] |
| 187 | + for module_and_args_str in value.split(): |
| 188 | + module_and_args = module_and_args_str.split(":") |
| 189 | + module = module_and_args[0] |
| 190 | + args = module_and_args[1:] |
| 191 | + filter_module = __import__(module, {}, {}, ['']) |
| 192 | + filter_instance = filter_module.Filter(*args) |
| 193 | + config['filters'].append(filter_instance) |
| 194 | + else: |
| 195 | + die("unrecognized paramter in conf file %s: %s\n" % (config_file, var)) |
| 196 | + |
| 197 | + if 'domain' not in config or 'source' not in config: |
| 198 | + die("must specify domain name and source in conf file %s\n", config_file) |
| 199 | + sys.stderr.write("read configuration from %s\n" % config_file) |
| 200 | + |
| 201 | +def reread(signum, frame): |
| 202 | + read_config() |
| 203 | + |
| 204 | +def die(msg): |
| 205 | + sys.stderr.write(msg) |
| 206 | + sys.exit(-1) |
| 207 | + |
| 208 | +def usage(cmd): |
| 209 | + die("Usage: %s [conf file]\n" % cmd) |
| 210 | + |
| 211 | +config_files = {} |
| 212 | +listen_port = 53 |
| 213 | +listen_host = '' |
| 214 | + |
| 215 | +try: |
| 216 | + options, filenames = getopt.getopt(sys.argv[1:], "p:h:") |
| 217 | +except getopt.GetoptError: |
| 218 | + usage(sys.argv[0]) |
| 219 | + |
| 220 | +for option, value in options: |
| 221 | + if option == "-p": |
| 222 | + listen_port = int(value) |
| 223 | + elif option == "-h": |
| 224 | + listen_host = value |
| 225 | +if not filenames: |
| 226 | + filenames = ['pymds.conf'] |
| 227 | +for f in filenames: |
| 228 | + if f in config_files: |
| 229 | + raise Exception("repeated configuration") |
| 230 | + config_files[f] = {} |
| 231 | + |
| 232 | +sys.stdout.write("%s starting on port %d\n" % (sys.argv[0], listen_port)) |
| 233 | +read_config() |
| 234 | +signal.signal(signal.SIGHUP, reread) |
| 235 | +for config in config_files.values(): |
| 236 | + sys.stdout.write("%s: serving for domain %s\n" % (sys.argv[0], ".".join(config['domain']))) |
| 237 | +sys.stdout.flush() |
| 238 | +sys.stderr.flush() |
| 239 | +serve() |
0 commit comments