#!/usr/bin/python

# Python Domain Name Server
# Copyright (C) 2002 Digital Lumber, Inc.

# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.

# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import asynchat
import asyncore
import random
import select
import signal
import socket
import string
import StringIO
import sys
import time
import traceback
import types
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
     ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT

import sipb_xen_database

# EXAMPLE ZONE FILE DATA STRUCTURE

# NOTE:
# There are no trailing dots in the internal data
# structure.  Although it's hard to tell by reading
# the RFC's, the dots on the end of names are just
# used internally by the resolvers and servers to
# see if they need to append a domain name onto
# the end of names.  There are no trailing dots
# on names in queries on the network.
examplenet = {'example.net':{'SOA':[{'class':'IN', 'ttl':10, 'mname':'ns1.example.net', 'rname':'hostmaster.example.net', 'serial':1, 'refresh':10800, 'retry':3600, 'expire':604800, 'minimum':3600}], 'NS':[{'class':'IN', 'ttl':10, 'nsdname':'ns1.example.net'}, {'ttl':10, 'nsdname':'ns2.example.net'}], 'MX':[{'class':'IN', 'ttl':10, 'preference':10, 'exchange':'mail.example.net'}]}, 'server1.example.net':{'A':[{'class':'IN', 'ttl':10, 'address':'10.1.2.3'}]}, 'www.example.net':{'CNAME':[{'class':'IN', 'ttl':10, 'cname':'server1.example.net'}]}, 'router.example.net':{'A':[{'class':'IN', 'ttl':10, 'address':'10.1.2.1'}, {'class':'IN', 'ttl':10, 'address':'10.2.1.1'}]} } # setup logging defaults loglevel = 0 logfile = sys.stdout try: file except NameError: def file(name, mode='r', buffer=0): return open(name, mode, buffer) def log(level,msg): if level <= loglevel: logfile.write(msg+'\n') def timestamp(): return time.strftime('%m/%d/%y %H:%M:%S')+ '-' def inttoasc(number): try: hs = hex(number)[2:] except: log(0,'inttoasc cannot convert ' + repr(number)) if hs[-1:].upper() == 'L': hs = hs[:-1] result = '' while len(hs) > 2: result = chr(int(hs[-2:],16)) + result hs = hs[:-2] result = chr(int(hs,16)) + result return result def asctoint(ascnum): rascnum = '' for i in range(len(ascnum)-1,-1,-1): rascnum = rascnum + ascnum[i] result = 0 count = 0 for c in rascnum: x = ord(c) << (8*count) result = result + x count = count + 1 return result def ipv6net_aton(ip_string): packed_ip = '' # first account for shorthand syntax pieces = ip_string.split(':') pcount = 0 for part in pieces: if part != '': pcount = pcount + 1 if pcount < 8: rs = '0:'*(8-pcount) ip_string = ip_string.replace('::',':'+rs) if ip_string[0] == ':': ip_string = ip_string[1:] pieces = ip_string.split(':') for part in pieces: # pad with the zeros i = 4-len(part) part = i*'0'+part packed_ip = packed_ip + chr(int(part[:2],16))+ chr(int(part[2:],16)) return packed_ip def ipv6net_ntoa(packed_ip): ip_string = '' 
count = 0 for c in packed_ip: ip_string = ip_string + hex(ord(c))[2:] count = count + 1 if count == 2: ip_string = ip_string + ':' count = 0 return ip_string[:-1] def getversion(qname, id, rd, ra, versionstr): msg = message() msg.header.id = id msg.header.qr = 1 msg.header.aa = 1 msg.header.rd = rd msg.header.ra = ra msg.header.rcode = 0 msg.question.qname = qname msg.question.qtype = 'TXT' msg.question.qclass = 'CH' if qname == 'version.bind': msg.header.ancount = 2 msg.answerlist.append({qname:{'CNAME':[{'cname':'version.oak', 'ttl':360000, 'class':'CH'}]}}) msg.answerlist.append({'version.oak':{'TXT':[{'txtdata':versionstr, 'ttl':360000, 'class':'CH'}]}}) else: msg.header.ancount = 1 msg.answerlist.append({qname:{'TXT':[{'txtdata':versionstr, 'ttl':360000, 'class':'CH'}]}}) return msg def getrcode(rcode): if rcode == 0: rcodestr = 'NOERROR(No error condition)' elif rcode == 1: rcodestr = 'FORMERR(Format Error)' elif rcode == 2: rcodestr = 'SERVFAIL(Internal failure)' elif rcode == 3: rcodestr = 'NXDOMAIN(Name does not exist)' elif rcode == 4: rcodestr = 'NOTIMP(Not Implemented)' elif rcode == 5: rcodestr = 'REFUSED(Security violation)' elif rcode == 6: rcodestr = 'YXDOMAIN(Name exists)' elif rcode == 7: rcodestr = 'YXRRSET(RR exists)' elif rcode == 8: rcodestr = 'NXRRSET(RR does not exist)' elif rcode == 9: rcodestr = 'NOTAUTH(Server not Authoritative)' elif rcode == 10: rcodestr = 'NOTZONE(Name not in zone)' else: rcodestr = 'Unknown RCODE(' + str(rcode) + ')' return rcodestr def printrdata(dnstype, rdata): if dnstype == 'A': return rdata['address'] elif dnstype == 'MX': return str(rdata['preference'])+'\t'+rdata['exchange']+'.' elif dnstype == 'NS': return rdata['nsdname']+'.' elif dnstype == 'PTR': return rdata['ptrdname']+'.' elif dnstype == 'CNAME': return rdata['cname']+'.' elif dnstype == 'SOA': return (rdata['mname']+'.\t'+rdata['rname']+'. 
(\n'+35*' '+str(rdata['serial'])+'\n'+ 35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+ str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )') def makezonedatalist(zonedata, origin): # unravel structure into list zonedatalist = [] # get soa first soanode = zonedata[origin] zonedatalist.append([origin+'.','SOA',soanode['SOA'][0]]) for item in soanode.keys(): if item != 'SOA': for listitem in soanode[item]: zonedatalist.append([origin+'.', item, listitem]) for nodename in zonedata.keys(): if nodename != origin: for item in zonedata[nodename].keys(): for listitem in zonedata[nodename][item]: zonedatalist.append([nodename+'.', item, listitem]) return zonedatalist def writezonefile(zonedata, origin, file): zonedatalist = makezonedatalist(zonedata, origin) for rr in zonedatalist: owner = rr[0] dnstype = rr[1] line = (owner + (35-len(owner))*' ' + str(rr[2]['ttl']) + '\t\tIN\t' + dnstype + '\t' + printrdata(dnstype, rr[2])) file.write(line + '\n') def readzonefiles(zonedict): for k in zonedict.keys(): filepath = zonedict[k]['filename'] try: pr = zonefileparser() pr.parse(zonedict[k]['origin'],filepath) zonedict[k]['zonedata'] = pr.getzdict() except ZonefileError, lineno: log(0,'Error reading zone file ' + filepath + ' at line ' + str(lineno) + '\n') del zonedict[k] def slowloop(tofunc='',timeout=5.0): if not tofunc: def tofunc(): return map = asyncore.socket_map while map: r = []; w=[]; e=[] for fd, obj in map.items(): if obj.readable(): r.append(fd) if obj.writable(): w.append(fd) try: starttime = time.time() r,w,e = select.select(r,w,e,timeout) endtime = time.time() if endtime-starttime >= timeout: tofunc() except select.error, err: if err[0] != EINTR: raise r=[]; w=[]; e=[] log(0,'ERROR in select') for fd in r: try: obj=map[fd] except KeyError: log(0,'KeyError in socket map') continue try: obj.handle_read_event() except: log(0,'calling HANDLE ERROR from loop') log(0,repr(obj)) obj.handle_error() for fd in w: try: obj=map[fd] except 
KeyError: log(0,'KeyError in socket map') continue try: obj.handle_read_event() except: log(0,'calling HANDLE ERROR from loop') log(0,repr(obj)) obj.handle_error() def fastloop(tofunc='',timeout=5.0): if not tofunc: def tofunc(): return polltimeout = timeout*1000 map = asyncore.socket_map while map: regfds = 0 pollobj = select.poll() for fd, obj in map.items(): flags = 0 if obj.readable(): flags = select.POLLIN if obj.writable(): flags = flags | select.POLLOUT if flags: pollobj.register(fd, flags) regfds = regfds + 1 try: starttime = time.time() r = pollobj.poll(polltimeout) endtime = time.time() if endtime-starttime >= timeout: tofunc() except select.error, err: if err[0] != EINTR: raise r = [] log(0,'ERROR in select') for fd, flags in r: try: obj = map[fd] badvals = (select.POLLPRI + select.POLLERR + select.POLLHUP + select.POLLNVAL) if (flags & badvals): if (flags & select.POLLPRI): log(0,'POLLPRI') if (flags & select.POLLERR): log(0,'POLLERR') if (flags & select.POLLHUP): log(0,'POLLHUP') if (flags & select.POLLNVAL): log(0,'POLLNVAL') obj.handle_error() else: if (flags & select.POLLIN): obj.handle_read_event() if (flags & select.POLLOUT): obj.handle_write_event() except KeyError: log(0,'KeyError in socket map') continue except: # print traceback sf = StringIO.StringIO() traceback.print_exc(file=sf) log(0,'ERROR IN LOOP:') log(0,sf.getvalue()) sf.close() log(0,repr(obj)) obj.handle_error() if hasattr(select,'poll'): loop = fastloop else: loop = slowloop class ZonefileError(Exception): def __init__(self, linenum, errordesc=''): self.linenum = linenum self.errordesc = errordesc def __str__(self): return str(self.linenum) + ' (' + self.errordesc + ')' class zonefileparser: def __init__(self): self.zonedata = {} self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX', 'NS','PTR','RP','SOA','SRV','TXT'] def stripcomments(self, line): i = line.find(';') if i >= 0: line = line[:i] return line def strip(self, line): # strip trailing linefeeds if line[-1:] == '\n': line 
= line[:-1] return line def getzdict(self): return self.zonedata def addorigin(self, origin, name): if name[-1:] != '.': return name + '.' + origin else: return name[:-1] def getstrings(self, s): if s.find('"') == -1: return s.split() else: x = s.split('"') rlist = [] for i in x: if i != '' and i != ' ': rlist.append(i) return rlist def getlocsize(self, s): if s[-1:] == 'm': size = float(s[:-1])*100 else: size = float(s)*100 i = 0 while size > 9: size = size/10 i = i + 1 return (int(size),i) def getloclat(self, l,c): deg = float(l[0]) min = 0 secs = 0 if len(l) == 3: min = float(l[1]) secs = float(l[2]) elif len(l) == 2: min = float(l[1]) rval = ((((deg *60) + min) * 60) + secs) * 1000 if c in ['N','E']: rval = rval + (2**31) elif c in ['S','W']: rval = (2**31) - rval else: log(0,'ERROR: unsupported latitude/longitude direction') return long(rval) def getgname(self, name, iter): if name == '0' or name == 'O': return '' start = 0 offset = 0 width = 0 base = 'd' for x in range(name.count('$')): i = name.find('$',start) j = i start = i+1 if i>0: if name[i-1] == '\\': continue if len(name)>i+1: if name[i+1] == '$': continue if name[i+1] == '{': j = name.find('}',i+1) owb = name[i+2:j].split(',') if len(owb) == 1: offset = int(owb[0]) elif len(owb) == 2: offset = int(owb[0]) width = int(owb[1]) elif len(owb) == 3: offset = int(owb[0]) width = int(owb[1]) base = owb[2] val = iter - offset if base == 'd': rs = str(val) elif base == 'o': rs = oct(val) elif base == 'x': rs = hex(val)[2:].lower() elif base == 'X': rs = hex(val)[2:].upper() else: rs = '' if len(rs) > width: rs = (width-len(rs))*'0'+rs name = name[:i]+rs+name[j+1:] start = i+len(rs)+1 return name def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens): rdata = {} rdata['class'] = dnsclass rdata['ttl'] = ttl if dnstype == 'A': rdata['address'] = tokens[0] elif dnstype == 'AAAA': rdata['address'] = tokens[0] elif dnstype == 'CNAME': rdata['cname'] = self.addorigin(origin,tokens[0].lower()) elif dnstype == 
'HINFO': sl = self.getstrings(' '.join(tokens)) rdata['cpu'] = sl[0] rdata['os'] = sl[1] elif dnstype == 'LOC': if 'N' in tokens: i = tokens.index('N') else: i = tokens.index('S') lat = self.getloclat(tokens[0:i],tokens[i]) if 'E' in tokens: j = tokens.index('E') else: j = tokens.index('W') lng = self.getloclat(tokens[i+1:j],tokens[j]) size = self.getlocsize('1m') horiz_pre = self.getlocsize('1000m') vert_pre = self.getlocsize('10m') if len(tokens[j+1:]) == 2: size = self.getlocsize(tokens[-1:][0]) elif len(tokens[j+1:]) == 3: size = self.getlocsize(tokens[-2:-1][0]) horiz_pre = self.getlocsize(tokens[-1:][0]) elif len(tokens[j+1:]) == 4: size = self.getlocsize(tokens[-3:-2][0]) horiz_pre = self.getlocsize(tokens[-2:-1][0]) vert_pre = self.getlocsize(tokens[-1:][0]) if tokens[j+1][-1:] == 'm': alt = int((float(tokens[j+1][:-1])*100)+10000000) else: size = int((float(tokens[j+1])*100)+10000000) rdata['version'] = 0 rdata['size'] = size rdata['horiz_pre'] = horiz_pre rdata['vert_pre'] = vert_pre rdata['latitude'] = lat rdata['longitude'] = lng rdata['altitude'] = 0 elif dnstype == 'MX': rdata['preference'] = int(tokens[0]) rdata['exchange'] = self.addorigin(origin,tokens[1].lower()) elif dnstype == 'NS': rdata['nsdname'] = self.addorigin(origin,tokens[0].lower()) elif dnstype == 'PTR': rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower()) elif dnstype == 'RP': rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower()) rdata['txtdname'] = self.addorigin(origin,tokens[1].lower()) elif dnstype == 'SOA': rdata['mname'] = self.addorigin(origin,tokens[0].lower()) rdata['rname'] = self.addorigin(origin,tokens[1].lower()) rdata['serial'] = int(tokens[2]) rdata['refresh'] = int(tokens[3]) rdata['retry'] = int(tokens[4]) rdata['expire'] = int(tokens[5]) rdata['minimum'] = int(tokens[6]) elif dnstype == 'SRV': rdata['priority'] = int(tokens[0]) rdata['weight'] = int(tokens[1]) rdata['port'] = int(tokens[2]) rdata['target'] = self.addorigin(origin,tokens[3].lower()) 
elif dnstype == 'TXT': rdata['txtdata'] = self.getstrings(' '.join(tokens))[0] else: raise ZonefileError(lineno,'bad DNS type') return rdata def addrec(self, owner, dnstype, rrdata): if self.zonedata.has_key(owner): if not self.zonedata[owner].has_key(dnstype): self.zonedata[owner][dnstype] = [] else: self.zonedata[owner] = {} self.zonedata[owner][dnstype] = [] self.zonedata[owner][dnstype].append(rrdata) def parse(self, origin, f): closefile = 0 if type(f) != types.FileType: # must be a path try: f = file(f) closefile = 1 except: log(0,'Invalid path to zonefile') return lastowner = '' lastdnsclass = '' lastttl = 3600 lineno = 0 while 1: line = f.readline() if not line: break lineno = lineno + 1 line = self.stripcomments(line) line = self.strip(line) if not line: continue if line.find('(') >= 0: # grab lines until end paren if line.find(')') == -1: line2 = self.stripcomments(f.readline()) lineno = lineno + 1 line2 = self.strip(line2) line = line + line2 while line2.find(')') == -1: line2 = self.strip(self.stripcomments(f.readline())) lineno = lineno + 1 line = line + line2 # now strip the parenthesis line = line.replace(')','') line = line.replace('(','') # now line equals the entire RR entry tokens = line.split() if tokens[0].upper() == '$ORIGIN': try: origin = tokens[1].lower() except: raise ZonefileError(lineno, 'bad origin') elif tokens[0].upper() == '$INCLUDE': try: f2 = file(tokens[1].lower()) if len(tokens) > 2: self.parse(tokens[2].lower(), f2) else: self.parse(origin, f2) f2.close() except: raise ZonefileError(lineno, 'bad INCLUDE directive') elif tokens[0].upper() == '$TTL': try: lastttl = int(tokens[1]) except: raise ZonefileError(lineno, 'bad TTL directive') elif tokens[0].upper() == '$GENERATE': try: lhs = tokens[2].lower() dnstype = tokens[3].upper() rhs = tokens[4].lower() rng = tokens[1].split('-') start = int(rng[0]) i = rng[1].find('/') if i != -1: stop = int(rng[1][:i])+1 step = int(rng[1][i+1:]) else: stop = int(rng[1])+1 step = 1 for i in 
range(start,stop,step): grhs = self.getgname(rhs,i) if dnstype in ['NS','CNAME','PTR']: grhs = self.addorigin(origin,grhs) rrdata = self.getrrdata(origin, dnstype, 'IN', lastttl, [grhs]) glhs = self.addorigin(origin,self.getgname(lhs,i)) self.addrec(glhs,dnstype, rrdata) except KeyError: raise ZonefileError(lineno, 'bad GENERATE directive') else: try: # if line begins with blank then owner is last owner if line[0] in string.whitespace: owner = lastowner else: owner = tokens[0].lower() tokens = tokens[1:] if owner == '@': owner = origin elif owner[-1:] != '.': owner = owner + '.' + origin else: owner = owner[:-1] # strip off trailing dot # line format is either: [class] [ttl] type RDATA # or [ttl] [class] type RDATA # - items in brackets are optional # # need to figure out which token is type # and backfill the missing data count = 0 for token in tokens: if token.upper() in self.dnstypes: break count = count + 1 # the following strips off the ttl and class if they exist if count == 0: ttl = lastttl dnsclass = lastdnsclass elif count == 1: if tokens[0].isdigit(): ttl = int(tokens[0]) dnsclass = lastdnsclass else: ttl = lastttl dnsclass = tokens[0].upper() tokens = tokens[1:] elif count == 2: if tokens[0].isdigit(): ttl = int(tokens[0]) dnsclass = tokens[1].upper() else: ttl = int(tokens[1]) dnsclass = tokens[0].upper() tokens = tokens[2:] else: raise ZonefileError(lineno,'bad ttl or class') dnstype = tokens[0] # make sure all of the structure is there rrdata = self.getrrdata(origin, dnstype, dnsclass, ttl, tokens[1:]) self.addrec(owner, dnstype, rrdata) lastowner = owner lastttl = ttl lastdnsclass = dnsclass except: raise ZonefileError(lineno,'unable to parse line') if closefile: f.close() class dnsconfig: def __init__(self): # self.zonedb = zonedb({}) self.cached = {} self.loglevel = 0 def getview(self, msg, address, port): # return: # 1. a list of zone keys # 2. whether or not to use the resolver # (i.e. answer recursive queries) # 3. 
a list of forwarder addresses return ['servers.csail.mit.edu'], 1, [] def allowupdate(self, msg, address, port): # return 1 if updates are allowed # NOTE: can only update the zones # returned by the getview func return 1 def outpackets(self, packetlist): # modify outgoing packets return packetlist class dnsheader: def __init__(self, id=1): self.id = id # 16bit identifier generated by queryer self.qr = 0 # one bit field specifying query(0) or response(1) self.opcode = 0 # 4bit field specifying type of query self.aa = 0 # authoritative answer self.tc = 0 # message is not truncated self.rd = 1 # recursion desired self.ra = 0 # recursion available? self.z = 0 # reserved for future use self.rcode = 0 # response code (set in response) self.qdcount = 1 # number of questions, only 1 is supported self.ancount = 0 # number of rrs in the answer section self.nscount = 0 # number of name server rrs in authority section self.arcount = 0 # number or rrs in the additional section class dnsquestion: def __init__(self): self.qname = 'localhost' self.qtype = 'A' self.qclass = 'IN' class dnsupdatezone: pass class message: def __init__(self, msgdata=''): if msgdata: self.header = dnsheader() else: self.header = dnsheader(id=random.randrange(1,32768)) self.question = dnsquestion() self.answerlist = [] self.authlist = [] self.addlist = [] self.u = '' self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA', 7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS', 12:'PTR',13:'HINFO',14:'MINFO',15:'MX', 16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV', 38:'A6',39:'DNAME',251:'IXFR',252:'AXFR', 253:'MAILB',254:'MAILA',255:'ANY'} self.rqtypes = {} for key in self.qtypes.keys(): self.rqtypes[self.qtypes[key]] = key self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'} self.rqclasses = {} for key in self.qclasses.keys(): self.rqclasses[self.qclasses[key]] = key if msgdata: self.processpkt(msgdata) def getdomainname(self, data, i): log(4,'IN GETDOMAINNAME') domainname = '' gotpointer = 0 
labellength= ord(data[i]) log(4,'labellength:' + str(labellength)) i = i + 1 while labellength != 0: while labellength >= 192: # pointer if not gotpointer: rindex = i + 1 gotpointer = 1 log(4,'got pointer') i = asctoint(chr(ord(data[i-1]) & 63)+data[i]) log(4,'new index:'+str(i)) labellength = ord(data[i]) log(4,'labellength:' + str(labellength)) i = i + 1 if domainname: domainname = domainname + '.' + data[i:i+labellength] else: domainname = data[i:i+labellength] log(4,'domainname:'+domainname) i = i + labellength labellength = ord(data[i]) log(4,'labellength:' + str(labellength)) i = i + 1 if not gotpointer: rindex = i return domainname.lower(), rindex def getrrdata(self, type, msgdata, rdlength, i): log(4,'unpacking RR data') rdata = msgdata[i:i+rdlength] if type == 'A': return {'address':socket.inet_ntoa(rdata)} elif type == 'AAAA': return {'address':ipv6net_ntoa(rdata)} elif type == 'CNAME': cname, i = self.getdomainname(msgdata,i) return {'cname':cname} elif type == 'HINFO': cpulen = ord(rdata[0]) cpu = rdata[1:cpulen+1] return {'cpu':cpu, 'os':rdata[cpulen+2:]} elif type == 'LOC': return {'version':ord(rdata[0]), 'size':self.locsize(rdata[1]), 'horiz_pre':self.locsize(rdata[2]), 'vert_pre':self.locsize(rdata[3]), 'latitude':asctoint(rdata[4:8]), 'longitude':asctoint(rdata[8:12]), 'altitude':asctoint(rdata[12:16])} elif type == 'MX': exchange, i = self.getdomainname(msgdata,i+2) return {'preference':asctoint(rdata[:2]), 'exchange':exchange} elif type == 'NS': nsdname, i = self.getdomainname(msgdata,i) return {'nsdname':nsdname} elif type == 'PTR': ptrdname, i = self.getdomainname(msgdata,i) return {'ptrdname':ptrdname} elif type == 'RP': mboxdname, i = self.getdomainname(msgdata,i) txtdname, i = self.getdomainname(msgdata,i) return {'mboxdname':mboxdname, 'txtdname':txtdname} elif type == 'SOA': mname, i = self.getdomainname(msgdata,i) rname, i = self.getdomainname(msgdata,i) return {'mname':mname, 'rname':rname, 'serial':asctoint(msgdata[i:i+4]), 
'refresh':asctoint(msgdata[i+4:i+8]), 'retry':asctoint(msgdata[i+8:i+12]), 'expire':asctoint(msgdata[i+12:i+16]), 'minimum':asctoint(msgdata[i+16:i+20])} elif type == 'SRV': target, i = self.getdomainname(msgdata,i+6) return {'priority':asctoint(rdata[0:2]), 'weight':asctoint(rdata[2:4]), 'port':asctoint(rdata[4:6]), 'target':target} elif type == 'TXT': return {'txtdata':rdata[1:]} else: return {'rdata':rdata} def getrr(self, data, i): log(4,'unpacking RR name') name, i = self.getdomainname(data, i) type = asctoint(data[i:i+2]) type = self.qtypes.get(type,chr(type)) klass = asctoint(data[i+2:i+4]) klass = self.qclasses.get(klass,chr(klass)) ttl = asctoint(data[i+4:i+8]) rdlength = asctoint(data[i+8:i+10]) rrdata = self.getrrdata(type,data,rdlength,i+10) rrdata['ttl'] = ttl rrdata['class'] = klass rr = {name:{type:[rrdata]}} return rr, i+10+rdlength def processpkt(self, msgdata): self.header.id = asctoint(msgdata[:2]) self.header.qr = ord(msgdata[2]) >> 7 self.header.opcode = (ord(msgdata[2]) & 127) >> 3 if self.header.opcode == 5: # UPDATE packet log(4,'processing UPDATE packet') del self.header.aa del self.header.tc del self.header.rd del self.header.ra del self.header.qdcount del self.header.ancount del self.header.nscount del self.header.arcount del self.question self.zone = dnsupdatezone() del self.answerlist del self.authlist del self.addlist self.header.z = 0 self.header.rcode = ord(msgdata[3]) & 15 self.header.zocount = asctoint(msgdata[4:6]) self.header.prcount = asctoint(msgdata[6:8]) self.header.upcount = asctoint(msgdata[8:10]) self.header.arcount = asctoint(msgdata[10:12]) self.zolist = [] self.prlist = [] self.uplist = [] self.addlist = [] i = 12 for x in range(self.header.zocount): (dn, i) = self.getdomainname(msgdata,i) self.zone.zname = dn type = asctoint(msgdata[i:i+2]) self.zone.ztype = self.qtypes.get(type,chr(type)) klass = asctoint(msgdata[i+2:i+4]) self.zone.zclass = self.qclasses.get(klass,chr(klass)) i = i + 4 for x in 
range(self.header.prcount): rr, i = self.getrr(msgdata,i) self.prlist.append(rr) for x in range(self.header.upcount): rr, i = self.getrr(msgdata,i) self.uplist.append(rr) for x in range(self.header.arcount): rr, i = self.getrr(msgdata,i) self.adlist.append(rr) else: self.header.aa = (ord(msgdata[2]) & 4) >> 2 self.header.tc = (ord(msgdata[2]) & 2) >> 1 self.header.rd = ord(msgdata[2]) & 1 self.header.ra = ord(msgdata[3]) >> 7 self.header.z = (ord(msgdata[3]) & 112) >> 4 self.header.rcode = ord(msgdata[3]) & 15 self.header.qdcount = asctoint(msgdata[4:6]) self.header.ancount = asctoint(msgdata[6:8]) self.header.nscount = asctoint(msgdata[8:10]) self.header.arcount = asctoint(msgdata[10:12]) i = 12 for x in range(self.header.qdcount): log(4,'unpacking question') (dn, i) = self.getdomainname(msgdata,i) self.question.qname = dn rrtype = asctoint(msgdata[i:i+2]) self.question.qtype = self.qtypes.get(rrtype,chr(rrtype)) klass = asctoint(msgdata[i+2:i+4]) self.question.qclass = self.qclasses.get(klass,chr(klass)) i = i + 4 for x in range(self.header.ancount): log(4,'unpacking answer RR') rr, i = self.getrr(msgdata,i) self.answerlist.append(rr) for x in range(self.header.nscount): log(4,'unpacking auth RR') rr, i = self.getrr(msgdata,i) self.authlist.append(rr) for x in range(self.header.arcount): log(4,'unpacking additional RR') rr, i = self.getrr(msgdata,i) self.addlist.append(rr) return def pds(self, s, l): # pad string with chr(0)'s so that # return string length is l x = l - len(s) return x*chr(0) + s def locsize(self, s): x1 = ord(s) >> 4 x2 = ord(s) & 15 return (x1, x2) def packlocsize(self, x): return chr((x[0] << 4) + x[1]) def packdomainname(self, name, i, msgcomp): log(4,'packing domainname: ' + name) if name == '': return chr(0) if name in msgcomp.keys(): log(4,'using pointer for: ' + name) return msgcomp[name] packedname = '' tokens = name.split('.') for j in range(len(tokens)): packedname = packedname + chr(len(tokens[j])) + tokens[j] nameleft = 
'.'.join(tokens[j+1:]) if nameleft in msgcomp.keys(): log(4,'using pointer for: ' + nameleft) return packedname+msgcomp[nameleft] # haven't used a pointer so put this in the dictionary pointer = inttoasc(i) if len(pointer) == 1: msgcomp[name] = chr(192)+pointer else: msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1] log(4,'added pointer for ' + name + '(' + str(i) + ')') return packedname + chr(0) def packrr(self, rr, i, msgcomp): rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] if self.rqtypes.has_key(rrtype): typeval = self.rqtypes[rrtype] else: typeval = ord(rrtype) dbrec = rr[rrname][rrtype][0] ttl = dbrec['ttl'] rclass = self.rqclasses[dbrec['class']] packedrr = (self.packdomainname(rrname, i, msgcomp) + self.pds(inttoasc(typeval),2) + self.pds(inttoasc(rclass),2) + self.pds(inttoasc(ttl),4)) i = i + len(packedrr) + 2 if rrtype == 'A': rdata = socket.inet_aton(dbrec['address']) elif rrtype == 'AAAA': rdata = ipv6net_aton(dbrec['address']) elif rrtype == 'CNAME': rdata = self.packdomainname(dbrec['cname'], i, msgcomp) elif rrtype == 'HINFO': rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] + chr(len(dbrec['os'])) + dbrec['os']) elif rrtype == 'LOC': rdata = (chr(dbrec['version']) + self.packlocsize(dbrec['size']) + self.packlocsize(dbrec['horiz_pre']) + self.packlocsize(dbrec['vert_pre']) + self.pds(inttoasc(dbrec['latitude']),4) + self.pds(inttoasc(dbrec['longitude']),4) + self.pds(inttoasc(dbrec['altitude']),4)) elif rrtype == 'MX': rdata = (self.pds(inttoasc(dbrec['preference']),2) + self.packdomainname(dbrec['exchange'], i+2, msgcomp)) elif rrtype == 'NS': rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp) elif rrtype == 'PTR': rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp) elif rrtype == 'RP': rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp) i = i + len(rdata1) rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp) rdata = rdata1 + rdata2 elif rrtype == 'SOA': rdata1 = self.packdomainname(dbrec['mname'], i, 
msgcomp) i = i + len(rdata1) rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp) rdata = (rdata1 + rdata2 + self.pds(inttoasc(dbrec['serial']),4) + self.pds(inttoasc(dbrec['refresh']),4) + self.pds(inttoasc(dbrec['retry']),4) + self.pds(inttoasc(dbrec['expire']),4) + self.pds(inttoasc(dbrec['minimum']),4)) elif rrtype == 'SRV': rdata = (self.pds(inttoasc(dbrec['priority']),2) + self.pds(inttoasc(dbrec['weight']),2) + self.pds(inttoasc(dbrec['port']),2) + self.packdomainname(dbrec['target'], i+6, msgcomp)) elif rrtype == 'TXT': rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata'] else: rdata = dbrec['rdata'] return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata def buildpkt(self): # keep dictionary of names packed (so we can use pointers) msgcomp = {} # header if self.header.id > 65535: log(0,'building packet with bad ID field') self.header.id = 1 msgdata = inttoasc(self.header.id) if len(msgdata) == 1: msgdata = chr(0) + msgdata h1 = ((self.header.qr << 7) + (self.header.opcode << 3) + (self.header.aa << 2) + (self.header.tc << 1) + (self.header.rd)) h2 = ((self.header.ra << 7) + (self.header.z << 4) + (self.header.rcode)) msgdata = msgdata + chr(h1) + chr(h2) msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2) msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2) msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2) msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2) # question msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp) if self.rqtypes.has_key(self.question.qtype): typeval = self.rqtypes[self.question.qtype] else: typeval = ord(self.question.qtype) msgdata = msgdata + self.pds(inttoasc(typeval),2) if self.rqclasses.has_key(self.question.qclass): classval = self.rqclasses[self.question.qclass] else: classval = ord(self.question.qclass) msgdata = msgdata + self.pds(inttoasc(classval),2) # rr's # RR record format: # {'name' : {'type' : [rdata, rdata, ...]} # example: {'test.blah.net': 
{'A': [{'address': '10.1.1.2', # 'ttl': 3600L}]}} for rr in self.answerlist: log(4,'packing answer RR') msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) for rr in self.authlist: log(4,'packing auth RR') msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) for rr in self.addlist: log(4,'packing additional RR') msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) return msgdata def printpkt(self): print 'ID: ' +str(self.header.id) if self.header.qr: print 'QR: RESPONSE' else: print 'QR: QUERY' if self.header.opcode == 0: print 'OPCODE: STANDARD QUERY' elif self.header.opcode == 1: print 'OPCODE: INVERSE QUERY' elif self.header.opcode == 2: print 'OPCODE: SERVER STATUS REQUEST' elif self.header.opcode == 5: print 'UPDATE REQUEST' else: print 'OPCODE: UNKNOWN QUERY TYPE' if self.header.opcode != 5: if self.header.aa: print 'AA: AUTHORITATIVE ANSWER' else: print 'AA: NON-AUTHORITATIVE ANSWER' if self.header.tc: print 'TC: MESSAGE IS TRUNCATED' else: print 'TC: MESSAGE IS NOT TRUNCATED' if self.header.rd: print 'RD: RECURSION DESIRED' else: print 'RD: RECURSION NOT DESIRED' if self.header.ra: print 'RA: RECURSION AVAILABLE' else: print 'RA: RECURSION IS NOT AVAILABLE' if self.header.rcode == 1: printrcode = 'FORMERR' elif self.header.rcode == 2: printrcode = 'SERVFAIL' elif self.header.rcode == 3: printrcode = 'NXDOMAIN' elif self.header.rcode == 4: printrcode = 'NOTIMP' elif self.header.rcode == 5: printrcode = 'REFUSED' elif self.header.rcode == 6: printrcode = 'YXDOMAIN' elif self.header.rcode == 7: printrcode = 'YXRRSET' elif self.header.rcode == 8: printrcode = 'NXRRSET' elif self.header.rcode == 9: printrcode = 'NOTAUTH' elif self.header.rcode == 10: printrcode = 'NOTZONE' else: printrcode = 'NOERROR' print 'RCODE: ' + printrcode if self.header.opcode == 5: print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount) print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount) print 'NUMBER OF RRs in the Update 
Section: ' + str(self.header.upcount) print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount) print 'ZONE SECTION:' print 'zname: ' + self.zone.zname print 'zonetype: ' + self.zone.ztype print 'zoneclass: ' + self.zone.zclass print 'PREREQUISITE RRs:' for rr in self.prlist: print rr print 'UPDATE RRs:' for rr in self.uplist: print rr print 'ADDITIONAL RRs:' for rr in self.addlist: print rr else: print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount) print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount) print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount) print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount) print 'QUESTION SECTION:' print 'qname: ' + self.question.qname print 'querytype: ' + self.question.qtype print 'queryclass: ' + self.question.qclass print 'ANSWER RRs:' for rr in self.answerlist: print rr print 'AUTHORITY RRs:' for rr in self.authlist: print rr print 'ADDITIONAL RRs:' for rr in self.addlist: print rr class zonedb: def __init__(self, zdict): self.zdict = zdict self.updates = {} for k in self.zdict.keys(): if self.zdict[k]['type'] == 'slave': self.zdict[k]['lastupdatetime'] = 0 def error(self, id, qname, querytype, queryclass, rcode): error = message() error.header.id = id error.header.rcode = rcode error.header.qr = 1 error.question.qname = qname error.question.qtype = querytype error.question.qclass = queryclass return error def getorigin(self, zkey): origin = '' if self.zdict.has_key(zkey): origin = self.zdict[zkey]['origin'] return origin def getmasterip(self, zkey): masterip = '' if self.zdict.has_key(zkey): if self.zdict[zkey].has_key('masterip'): masterip = self.zdict[zkey]['masterip'] return masterip def zonetrans(self, query): # build a list of messages # each message contains one rr of the zone # the first and last message are the # SOA records origin = query.question.qname querytype = query.question.qtype zkey = '' for zonekey in self.zdict.keys(): if self.zdict[zonekey]['origin'] == 
query.question.qname: zkey = zonekey if not zkey: return [] zonedata = self.zdict[zkey]['zonedata'] queryid = query.header.id soarec = zonedata[origin]['SOA'][0] soa = {origin:{'SOA':[soarec]}} curserial = soarec['serial'] rrlist = [] if querytype == 'IXFR': clientserial = query.authlist[0][origin]['SOA'][0]['serial'] if clientserial < curserial: for i in range(clientserial,curserial+1): if self.updates[zkey].has_key(i): for rr in self.updates[zkey][i]['added']: rrlist.append(rr) for rr in self.updates[zkey][i]['removed']: rrlist.append(rr) if len(rrlist) > 0: rrlist.insert(0,soa) rrlist.append(soa) else: rrlist.append(soa) else: for nodename in zonedata.keys(): for rrtype in zonedata[nodename].keys(): if not (rrtype == 'SOA' and nodename == origin): for rr in zonedata[nodename][rrtype]: rrlist.append({nodename:{rrtype:[rr]}}) rrlist.insert(0,soa) rrlist.append(soa) msglist = [] for rr in rrlist: msg = message() msg.header.id = queryid msg.header.qr = 1 msg.header.aa = 1 msg.header.rd = 0 msg.header.qdcount = 1 msg.question.qname = origin msg.question.qtype = querytype msg.question.qclass = 'IN' msg.header.ancount = 1 msg.answerlist.append(rr) msglist.append(msg) return msglist def update_zone(self, rrlist, params): zonekey = params[0] zonedata = {} soa = rrlist.pop() origin = soa.keys()[0] for rr in rrlist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if zonedata.has_key(rrname): if not zonedata[rrname].has_key(rrtype): zonedata[rrname][rrtype] = [] else: zonedata[rrname] = {} zonedata[rrname][rrtype] = [] zonedata[rrname][rrtype].append(dbrec) self.zdict[zonekey]['zonedata'] = zonedata curtime = time.time() self.zdict[zonekey]['lastupdatetime'] = curtime try: f = file(self.zdict[zonekey]['filename'],'w') writezonefile(zonedata, self.zdict[zonekey]['origin'], f) f.close() except: log(0,'unable to write zone ' + zonekey + 'to disk') log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')') def remove_zone(self, 
zonekey): if self.zdict.has_key(zonekey): del self.zdict[zonekey] def getslaves(self, curtime): rlist = [] for k in self.zdict.keys(): if self.zdict[k]['type'] == 'slave': origin = self.zdict[k]['origin'] refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh'] if self.zdict[k]['lastupdatetime'] + refresh < curtime: rlist.append((k, origin, self.zdict[k]['masterip'])) return rlist def zmatch(self, qname, zkeys): for zkey in zkeys: if self.zdict.has_key(zkey): origin = self.zdict[zkey]['origin'] if qname.rfind(origin) != -1: return zkey return '' def getzlist(self, name, zone): if name == zone: return zlist = [] i = name.rfind(zone) if i == -1: return firstpart = name[:i-1] partlist = firstpart.split('.') partlist.reverse() lastpart = zone for x in range(len(partlist)): lastpart = partlist[x] + '.' + lastpart zlist.append(lastpart) return zlist def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc): # handle zone transfers seperately qname = query.question.qname querytype = query.question.qtype queryclass = query.question.qclass if querytype in ['AXFR','IXFR']: for zkey in self.zdict.keys(): if zkey in zkeys: if qname == self.zdict[zkey]['origin']: answerlist = self.zonetrans(query) break else: answerlist = [] cbfunc(query, addr, server, dorecursion, flist, answerlist) else: zonekey = self.zmatch(qname, zkeys) if zonekey: origin = self.zdict[zonekey]['origin'] zonedict = self.zdict[zonekey]['zonedata'] referral = 0 rranswerlist = [] rrnslist = [] rraddlist = [] answer = message() answer.header.aa = 1 answer.header.id = query.header.id answer.header.qr = 1 answer.header.opcode = query.header.opcode answer.header.rcode = 4 answer.header.ra = dorecursion answer.question.qname = query.question.qname answer.question.qtype = query.question.qtype answer.question.qclass = query.question.qclass answer.header.ra = dorecursion s = '.servers.csail.mit.edu' if qname.endswith(s): host = qname[:-len(s)] value = sipb_xen_database.NIC.get_by(hostname=host) 
if value is None: pass else: ip = value.ip rranswerlist.append({qname: {'A': [{'address': ip, 'class': 'IN', 'ttl': 10}]}}) if zonedict.has_key(qname): # found the node, now take care of CNAMEs if zonedict[qname].has_key('CNAME'): if querytype != 'CNAME': nodetype = 'CNAME' while nodetype == 'CNAME': rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}}) qname = zonedict[qname]['CNAME'][0]['cname'] if zonedict.has_key(qname): nodetype = zonedict[qname].keys()[0] else: # error, shouldn't have a CNAME that points to nothing return # if we get this far, then the record has matched and we should return # a reply that has no error (even if there is no info macthing the qtype) answer.header.rcode = 0 answernode = zonedict[qname] if querytype == 'ANY': for type in answernode.keys(): for rec in answernode[type]: rranswerlist.append({qname:{type:[rec]}}) elif answernode.has_key(querytype): for rec in answernode[querytype]: rranswerlist.append({qname:{querytype:[rec]}}) # do rrset ordering (cyclic) if len(answernode[querytype]) > 1: rec = answernode[querytype].pop(0) answernode[querytype].append(rec) else: # remove all cname rrs from answerlist rranswerlist = [] else: # would check for wildcards here (but aren't because they seem bad) # see if we need to give a referral zlist = self.getzlist(qname,origin) for zonename in zlist: if zonedict.has_key(zonename): if zonedict[zonename].has_key('NS'): answer.header.rcode = 0 referral = 1 for rec in zonedict[zonename]['NS']: rrnslist.append({zonename:{'NS':[rec]}}) nsdname = rec['nsdname'] # add glue records if they exist if zonedict.has_key(nsdname): if zonedict[nsdname].has_key('A'): for gluerec in zonedict[nsdname]['A']: rraddlist.append({nsdname:{'A':[gluerec]}}) # negative caching stuff if not referral: if not rranswerlist: # NOTE: RFC1034 section 4.3.4 says we should add the SOA record # to the additional section of the response. 
BIND adds # it to the ns section though answer.header.rcode = 3 rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}}) else: for rec in zonedict[origin]['NS']: rrnslist.append({origin:{'NS':[rec]}}) answer.header.ancount = len(rranswerlist) answer.header.nscount = len(rrnslist) answer.header.arcount = len(rraddlist) answer.answerlist = rranswerlist answer.authlist = rrnslist answer.addlist = rraddlist cbfunc(query, addr, server, dorecursion, flist, [answer]) else: cbfunc(query, addr, server, dorecursion, flist, []) def handle_update(self, msg, addr, ns): zkey = '' slaves = [] for zonekey in self.zdict.keys(): if (self.zdict[zonekey]['type'] == 'master' and self.zdict[zonekey]['origin'] == msg.zone.zname): zkey = zonekey if not zkey: log(2,'SENDING NOTAUTH UPDATE ERROR') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 9) return errormsg, '', slaves # find the slaves for the zone if self.zdict[zkey].has_key('slaves'): slaves = self.zdict[zkey]['slaves'] origin = self.zdict[zkey]['origin'] zd = self.zdict[zkey]['zonedata'] # check the permissions if not ns.config.allowupdate(msg, addr[0], addr[1]): log(2,'SENDING REFUSED UPDATE ERROR') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 5) return errormsg, origin, slaves # now check the prereqs temprrset = {} for rr in msg.prlist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if dbrec['ttl'] != 0: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves if rrname.rfind(msg.zone.zname) == -1: log(2,'NOTZONE(10)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 10) return errormsg, origin, slaves if dbrec['class'] == 'ANY': if dbrec['rdata']: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves 
if rrtype == 'ANY': if not zd.has_key(rrname): log(2,'NXDOMAIN(3)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 3) return errormsg, origin, slaves else: rrsettest = 0 if zd.has_key(rrname): if zd[rrname].has_key(rrtype): rrsettest = 1 if not rrsettest: log(2,'NXRRSET(8)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 8) return errormsg, origin, slaves if dbrec['class'] == 'NONE': if dbrec['rdata']: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves if rrtype == 'ANY': if zd.has_key(rrname): log(2,'YXDOMAIN(6)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 6) return errormsg, origin, slaves else: if zd.has_key(rrname): if zd[rrname].has_key(rrtype): log(2,'YXRRSET(7)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 7) return errormsg, origin, slaves if dbrec['class'] == msg.zone.zclass: if temprrset.has_key(rrname): if not temprrset[rrname].has_key(rrtype): temprrset[rrname][rrtype] = [] else: temprrset[rrname] = {} temprrset[rrname][rrtype] = [] temprrset[rrname][rrtype].append(dbrec) else: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves for nodename in temprrset.keys(): if not self.rrmatch(temprrset[nodename],zd[nodename]): log(2,'NXRRSET(8)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 8) return errormsg, origin, slaves # update section prescan for rr in msg.uplist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if rrname.rfind(msg.zone.zname) == -1: log(2,'NOTZONE(10)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 10) return errormsg, origin, slaves if dbrec['class'] == msg.zone.zclass: if rrtype in 
['ANY','MAILA','MAILB','AXFR']: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves elif dbrec['class'] == 'ANY': if dbrec['ttl'] != 0 or dbrec['rdata'] or rrtype in ['MAILA','MAILB','AXFR']: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves elif dbrec['class'] == 'NONE': if dbrec['ttl'] != 0 or rrtype in ['ANY','MAILA','MAILB','AXFR']: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves else: log(2,'FORMERROR(1)') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) return errormsg, origin, slaves # now handle actual update curserial = zd[msg.zone.zname]['SOA'][0]['serial'] # update the soa serial here clearupdatehist = 0 if len(msg.uplist) > 0: # initialize history structure if not self.updates.has_key(zkey): self.updates[zkey] = {} self.updates[zkey][curserial] = {'removed':[], 'added':[]} if curserial == 2**32: newserial = 2 clearupdatehist = 1 else: newserial = curserial + 1 self.updates[zkey][newserial] = {'removed':[], 'added':[]} zd[msg.zone.zname]['SOA'][0]['serial'] = newserial for rr in msg.uplist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if dbrec['class'] == msg.zone.zclass: if rrtype == 'SOA': if zd.has_key(rrname): if zd[rrname].has_key('SOA'): if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']: del zd[rrname]['SOA'][0] zd[rrname]['SOA'].append(dbrec) clearupdatehist = 1 elif rrtype == 'WKS': if zd.has_key(rrname): if zd[rrname].has_key('WKS'): rdata = zd[rrname]['WKS'][0] oldrr = {rrname:{'WKS':[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname]['WKS'][0] zd[rrname]['WKS'].append(dbrec) newrr = {rrname:{'WKS':[dbrec]}} self.updates[zkey][newserial]['added'].append(newrr) else: 
if zd.has_key(rrname): if not zd[rrname].has_key(rrtype): zd[rrname][rrtype] = [] else: zd[rrname] = {} zd[rrname][rrtype] = [] zd[rrname][rrtype].append(dbrec) newrr = {rrname:{rrtype:[dbrec]}} self.updates[zkey][newserial]['added'].append(newrr) elif dbrec['class'] == 'ANY': if rrtype == 'ANY': if rrname == msg.zone.zname: if zd.has_key(rrname): for dnstype in zd[rrname].keys(): if dnstype not in ['SOA','NS']: for rdata in zd[rrname][dnstype]: oldrr = {rrname:{dnstype:[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname][dnstype] else: if zd.has_key(rrname): for dnstype in zd[rrname].keys(): for rdata in zd[rrname][dnstype]: oldrr = {rrname:{dnstype:[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname] else: if zd.has_key(rrname): if zd[rrname].has_key(rrtype): if rrname == msg.zone.zname: if rrtype not in ['SOA','NS']: for rdata in zd[rrname][dnstype]: oldrr = {rrname:{dnstype:[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname][rrtype] else: for rdata in zd[rrname][dnstype]: oldrr = {rrname:{dnstype:[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname][rrtype] elif dbrec['class'] == 'NONE': if not (rrname == msg.zone.zname and rrtype in ['SOA','NS']): if zd.had_key(rrname): if zd[rrname].has_key(rrtype): for i in range(len(zd[rrname][rrtype])): if dbrec == zd[rrname][rrtype][i]: rdata = zd[rrname][dnstype][i] oldrr = {rrname:{dnstype:[rdata]}} self.updates[zkey][curserial]['removed'].append(oldrr) del zd[rrname][rrtype][i] if len(zd[rrname][rrtype]) == 0: del zd[rrname][rrtype] if clearupdatehist: self.updates[zkey] = {} log(2,'SENDING UPDATE NOERROR MSG') noerrormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 0) return noerrormsg, origin, slaves class dnscache: def __init__(self,cachezone): self.cachedb = cachezone # go through and set all of the root ttls to zero for node in self.cachedb.keys(): for rtype in 
self.cachedb[node].keys(): for rr in self.cachedb[node][rtype]: rr['ttl'] = 0 if rtype == 'NS': rr['rtt'] = 0 # add special entries for localhost self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]} self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]} self.cachedb['']['SOA'] = [] self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb', 'rname':'cachedb@localhost','serial':1,'refresh':10800, 'retry':3600,'expire':604800,'minimum':3600}) def hasrdata(self, irrdata, rrdatalist): # compare everything but ttls test = 0 testrrdata = irrdata.copy() del testrrdata['ttl'] for rrdata in rrdatalist: temprrdata = rrdata.copy() del temprrdata['ttl'] if temprrdata == testrrdata: test = 1 return test def add(self, rr, qzone, nsdname): # NOTE: can't cache records from sites # that don't own those records (i.e. example.com # can't give us A records for www.example.net) name = rr.keys()[0] if (qzone != '') and (name[-len(qzone):] != qzone): log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone) return rtype = rr[name].keys()[0] rdata = rr[name][rtype][0] if rdata['ttl'] < 3600: log(2,'low ttl: ' + str(rdata['ttl'])) rdata['ttl'] = 3600 rdata['ttl'] = int(time.time() + rdata['ttl']) if rtype == 'NS': rdata['rtt'] = 0 name = name.lower() rtype = rtype.upper() if self.cachedb.has_key(name): if self.cachedb[name].has_key(rtype): if not self.hasrdata(rdata, self.cachedb[name][rtype]): self.cachedb[name][rtype].append(rdata) log(3,'appended rdata to ' + name + '(' + rtype + ') in cache') else: log(3,'same rdata for ' + name + '(' + rtype + ') is already in cache') else: self.cachedb[name][rtype] = [rdata] log(3,'appended ' + rtype + ' and rdata to node ' + name + ' in cache') else: self.cachedb[name] = {rtype:[rdata]} log(3,'added node ' + name + '(' + rtype + ') to cache') self.reap() def addneg(self, qname, querytype, queryclass): if not self.cachedb.has_key(qname): 
self.cachedb['qname'] = {querytype: [{'ttl':time.time()+3600}]} else: if not self.cachedb[qname].has_key(querytype): self.cachedb[qname][querytype] = [{'ttl':time.time()+3600}] def haskey(self, qname, querytype, msg=''): log(3,'looking for ' + qname + '(' + querytype + ') in cache') if self.cachedb.has_key(qname): rranswerlist = [] rrnslist = [] rraddlist = [] if self.cachedb[qname].has_key('CNAME'): if querytype != 'CNAME': nodetype = 'CNAME' while nodetype == 'CNAME': if len(self.cachedb[qname]['CNAME'][0].keys()) > 1: log(3,'Adding CNAME to cache answer') rranswerlist.append({qname:{'CNAME':[self.cachedb[qname]['CNAME'][0]]}}) qname = self.cachedb[qname]['CNAME'][0]['cname'] if self.cachedb.has_key(qname): nodetype = self.cachedb[qname].keys()[0] else: # shouldn't have a CNAME that points to nothing return if querytype == 'ANY': for type in self.cache[qname].keys(): for rec in self.cachedb[qname][type]: # can't append negative entries if len(rec.keys()) > 1: rranswerlist.append({qname:{type:[rec]}}) elif self.cachedb[qname].has_key(querytype): for rec in self.cachedb[qname][querytype]: if len(rec.keys()) > 1: rranswerlist.append({qname:{querytype:[rec]}}) if rranswerlist: if msg: answer = message() answer.header.id = msg.header.id answer.header.qr = 1 answer.header.opcode = msg.header.opcode answer.header.ra = 1 answer.question.qname = msg.question.qname answer.question.qtype = msg.question.qtype answer.question.qclass = msg.question.qclass answer.header.rcode = 0 answer.header.ancount = len(rranswerlist) answer.answerlist = rranswerlist return answer else: return 1 else: log(3,'Cache has no node for ' + qname) def getnslist(self, qname): # find the best nameserver to ask from the cache tokens = qname.split('.') nsdict = {} curtime = time.time() for i in range(len(tokens)): domainname = '.'.join(tokens[i:]) if self.cachedb.has_key(domainname): if self.cachedb[domainname].has_key('NS'): for nsrec in self.cachedb[domainname]['NS']: badserver = 0 if 
nsrec.has_key('badtill'): if nsrec['badtill'] < curtime: del nsrec['badtill'] else: badserver = 1 if badserver: log(2,'BAD SERVER, not using ' + nsrec['nsdname']) if self.cachedb.has_key(nsrec['nsdname']) and not badserver: if self.cachedb[nsrec['nsdname']].has_key('A'): for arec in self.cachedb[nsrec['nsdname']]['A']: nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'], 'ip':arec['address']} if nsdict: break if not nsdict: domainname = '' # nothing in the cache matches so give back the root servers for nsrec in self.cachedb['']['NS']: badserver = 0 if nsrec.has_key('badtill'): if curtime > nsrec['badtill']: del nsrec['badtill'] else: badserver = 1 if not badserver: for arec in self.cachedb[nsrec['nsdname']]['A']: nsdict[(nsrec['rtt'])] = {'name':nsrec['nsdname'],'ip':arec['address']} return (domainname, nsdict) def badns(self, zonename, nsdname): if self.cachedb.has_key(zonename): if self.cachedb[zonename].has_key('NS'): for nsrec in self.cachedb[zonename]['NS']: if nsrec['nsdname'] == nsdname: log(2,'Setting ' + nsdname + ' as bad nameserver') nsrec['badtill'] = time.time() + 3600 def updatertt(self, qname, zone, rtt): if self.cachedb.has_key(zone): if self.cachedb[zone].has_key('NS'): for rr in self.cachedb[zone]['NS']: if rr['nsdname'] == qname: log(2,'updating rtt for ' + qname + ' to ' + str(rtt)) rr['rtt'] = rtt def reap(self): # expire all old records ntime = time.time() for nodename in self.cachedb.keys(): for rrtype in self.cachedb[nodename].keys(): for rdata in self.cachedb[nodename][rrtype]: ttl = rdata['ttl'] if ttl != 0: if ttl < ntime: self.cachedb[nodename][rrtype].remove(rdata) if len(self.cachedb[nodename][rrtype]) == 0: del self.cachedb[nodename][rrtype] if len(self.cachedb[nodename]) == 0: del self.cachedb[nodename] return def zonetrans(self, queryid): # build a list of messages # each message contains one rr of the zone # the first and last message are the # SOA records zonedata = self.cachedb rrlist = [] soa = {'':{'SOA':[zonedata['']['SOA'][0]]}} 
for nodename in zonedata.keys(): for rrtype in zonedata[nodename].keys(): if not (rrtype == 'SOA' and nodename == ''): for rr in zonedata[nodename][rrtype]: rrlist.append({nodename:{rrtype:[rr]}}) rrlist.insert(0,soa) rrlist.append(soa) msglist = [] for rr in rrlist: msg = message() msg.header.id = queryid msg.header.qr = 1 msg.header.aa = 1 msg.header.rd = 0 msg.header.qdcount = 1 msg.question.qname = 'cache' msg.question.qtype = 'AXFR' msg.question.qclass = 'IN' msg.header.ancount = 1 msg.answerlist.append(rr) msglist.append(msg) return msglist class gethostaddr(asyncore.dispatcher): def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'): asyncore.dispatcher.__init__(self) self.msg = message() self.msg.question.qname = hostname self.msg.question.qtype = 'A' self.cbfunc = cbfunc self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) def handle_read(self): replydata, addr = self.socket.recvfrom(1500) self.close() try: replymsg = message(replydata) except: log(0,'unable to process packet') return answername = replymsg.question.qname cname = '' # go through twice to catch cnames after A recs for rr in replymsg.answerlist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if rrname == answername and rrtype == 'CNAME': answername = dbrec['cname'] cname = answername for rr in replymsg.answerlist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] dbrec = rr[rrname][rrtype][0] if rrname == answername and rrtype == 'A': self.cbfunc(dbrec['address']) return # if we got a cname and no A send query for cname if cname: self.msg = message() self.msg.question.qname = cname self.msg.question.qtype = 'A' self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) else: self.cbfunc('') def writable(self): return 0 def handle_write(self): pass def 
handle_connect(self): pass def handle_close(self): self.close() def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class simpleudprequest(asyncore.dispatcher): def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''): asyncore.dispatcher.__init__(self) self.gotanswer = 0 self.msg = msg self.cbfunc = cbfunc self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) self.outqkey = outqkey self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) def handle_read(self): replydata, addr = self.socket.recvfrom(1500) self.close() try: replymsg = message(replydata) except: log(0,'unable to process packet') return self.cbfunc(replymsg, self.outqkey) def writable(self): return 0 def handle_write(self): pass def handle_connect(self): pass def handle_close(self): self.close() def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class simpletcprequest(asyncore.dispatcher): def __init__(self, msg, cbfunc, cbparams=[], serveraddr='127.0.0.1', errorfunc=''): asyncore.dispatcher.__init__(self) self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.query = msg self.cbfunc = cbfunc self.cbparams = cbparams self.errorfunc = errorfunc msgdata = msg.buildpkt() ml = inttoasc(len(msgdata)) if len(ml) == 1: ml = chr(0) + ml self.buffer = ml+msgdata self.rbuffer = '' self.rmsgleft = 0 self.rrlist = [] log(2,'sending tcp request to ' + serveraddr) self.connect((serveraddr,53)) def recv (self, buffer_size): try: data = self.socket.recv (buffer_size) if not data: # a closed connection is indicated by signaling # a read condition, and having recv() return 0. 
self.handle_close() return '' else: return data except socket.error, why: # winsock sometimes throws ENOTCONN if why[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]: self.handle_close() return '' else: raise socket.error, why def handle_connect(self): pass def handle_msg(self, msg): if self.query.question.qtype == 'AXFR': if len(self.rrlist) == 0: if len(msg.answerlist) == 0: if self.errorfunc: self.errorfunc(self.cbparams[0]) self.close() return rr = msg.answerlist[0] rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] self.rrlist.append(rr) if rrtype == 'SOA' and len(self.rrlist) > 1: self.close() if self.cbparams: self.cbfunc(self.rrlist, self.cbparams) else: self.cbfunc(self.rrlist) else: self.close() if self.cbparams: self.cbfunc(msg, self.cbparams) else: self.cbfunc(msg) def handle_read(self): data = self.recv(8192) if len(self.rbuffer) == 0: self.rmsglength = asctoint(data[:2]) data = data[2:] self.rbuffer = self.rbuffer + data while len(self.rbuffer) >= self.rmsglength and self.rmsglength != 0: msgdata = self.rbuffer[:self.rmsglength] self.rbuffer = self.rbuffer[self.rmsglength:] if len(self.rbuffer) == 0: self.rmsglength = 0 else: self.rmsglength = asctoint(self.rbuffer[:2]) self.rbuffer = self.rbuffer[2:] try: self.handle_msg(message(msgdata)) except: return def writable(self): return (len(self.buffer) > 0) def handle_write(self): sent = self.send(self.buffer) self.buffer = self.buffer[sent:] def handle_close(self): if self.errorfunc: self.errorfunc(self.query.question.qname) self.close() def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class udpdnsserver(asyncore.dispatcher): def __init__(self, port, dnsserver): asyncore.dispatcher.__init__(self) self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) self.bind(('',port)) self.dnsserver = dnsserver self.maxmsgsize = 
500 def handle_read(self): try: while 1: msgdata, addr = self.socket.recvfrom(1500) self.dnsserver.handle_packet(msgdata, addr, self) except socket.error, why: if why[0] != asyncore.EWOULDBLOCK: raise socket.error, why def sendpackets(self, msglist, addr): for msg in msglist: msgdata = msg.buildpkt() if len(msgdata) > self.maxmsgsize: msg.header.tc = 1 # take off all the answers to ensure # the packet size is small enough msg.header.ancount = 0 msg.header.nscount = 0 msg.header.arcount = 0 msg.answerlist = [] msg.authlist = [] msg.addlist = [] msgdata = msg.buildpkt() self.sendto(msgdata, addr) def writable(self): return 0 def handle_write(self): pass def handle_connect(self): pass def handle_close(self): # print '1:In handle close' return def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class tcpdnschannel(asynchat.async_chat): def __init__(self, server, s, addr): asynchat.async_chat.__init__(self, s) self.server = server self.addr = addr self.set_terminator(None) self.databuffer = '' self.msglength = 0 log(3,'Created new tcp channel') def collect_incoming_data(self, data): if self.msglength == 0: self.msglength = asctoint(data[:2]) data = data[2:] self.databuffer = self.databuffer + data if len(self.databuffer) == self.msglength: # got entire message self.server.dnsserver.handle_packet(self.databuffer, self.addr, self) self.databuffer = '' def sendpackets(self, msglist, addr): for msg in msglist: x = msg.buildpkt() ml = inttoasc(len(x)) if len(ml) == 1: ml = chr(0) + ml self.push(ml+x) self.close() def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class tcpdnsserver(asyncore.dispatcher): def __init__(self, port, dnsserver): asyncore.dispatcher.__init__(self) self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.set_reuse_addr() self.bind(('',port)) self.listen(5) self.dnsserver = dnsserver def handle_accept(self): conn, addr = self.accept() 
tcpdnschannel(self, conn, addr) def handle_close(self): self.close() def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class nameserver: def __init__(self, resolver, localconfig): self.resolver = resolver self.config = localconfig self.zdb = self.config.zonedatabase self.last_reap_time = time.time() self.maint_int = 10 self.slavesupdating = [] self.notifys = [] self.sentnotify = [] self.notify_retry_time = 30 self.notify_retries = 4 self.askedsoa = {} self.soatimeout = 10 def error(self, id, qname, querytype, queryclass, rcode): error = message() error.header.id = id error.header.rcode = rcode error.header.qr = 1 error.question.qname = qname error.question.qtype = querytype error.question.qclass = queryclass return error def need_zonetransfer(self, zkey, origin, masterip, trynum=0): self.askedsoa[zkey] = {'masterip':masterip, 'senttime':time.time(), 'origin':origin, 'trynum':trynum+1} query = message() query.header.id = random.randrange(1,32768) query.header.rd = 0 query.question.qname = origin query.question.qtype = 'SOA' query.question.qclass = 'IN' log(3,'slave checking for new data in ' + origin) simpleudprequest(query, self.handle_soaquery, masterip, zkey) def handle_soaquery(self, msg, zkey): origin = msg.question.qname masterip = self.askedsoa[zkey]['masterip'] del self.askedsoa[zkey] if zkey not in self.slavesupdating: self.slavesupdating.append(zkey) query = message() query.header.id = random.randrange(1,32768) query.header.rd = 0 query.question.qname = origin query.question.qtype = 'AXFR' query.question.qclass = 'IN' log(3,'Updating slave zone: ' + zkey) simpletcprequest(query, self.handle_zonetrans, [zkey],masterip,self.handle_zterror) def handle_zonetrans(self, rrlist, params): log(1,'handling zone transfer') zonekey = params[0] self.zdb.update_zone(rrlist, params) self.slavesupdating.remove(zonekey) def handle_zterror(self, zonekey): self.slavesupdating.remove(zonekey) 
self.zdb.remove_zone(zonekey) def rrmatch(self, rrset1, rrset2): for rrtype in rrset1.keys(): if rrtype not in rrset2.keys(): return else: if len(rrset1[rrtype]) != len(rrset2[rrtype]): return return 1 def process_notify(self, msg, ipaddr, port): (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port) goodzkey = '' for zkey in zkeys: origin = self.zdb.getorigin(zkey) if origin == msg.question.qname: masterip = self.zdb.getmasterip(zkey) if masterip: goodzkey = zkey if goodzkey: log(3,'got NOTIFY from ' + masterip) self.need_zonetransfer(goodzkey, origin, masterip, 0) return def notify(self): curtime = time.time() for origin, ipaddr, trynum, senttime in self.sentnotify: if senttime + self.notify_retry_time > curtime: self.notifys.append((origin, ipaddr, trynum)) self.sentnotify.remove((origin, ipaddr, trynum, senttime)) for origin, ipaddr, trynum in self.notifys: msg = message() msg.question.qname = origin msg.question.qtype = 'SOA' msg.question.qclass = 'IN' msg.header.opcode = 4 # there probably is a better way to do this if self.resolver: self.resolver.send_to([msg],(ipaddr,53)) if trynum+1 <= self.notify_retries: self.sentnotify.append((origin,ipaddr,trynum+1,curtime)) self.notifys = [] def handle_packet(self, msgdata, addr, server): # self.reap() try: msg = message(msgdata) except: return # find a matching view (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1]) if not msg.header.qr and msg.header.opcode == 5: log(2,'GOT UPDATE PACKET') # check the zone section if (msg.header.zocount != 1 or msg.zone.ztype != 'SOA' or msg.zone.zclass != 'IN'): log(2,'SENDING FORMERR UPDATE ERROR') errormsg = self.error(msg.header.id, msg.zone.zname, msg.zone.ztype, msg.zone.zclass, 1) server.sendpackets([errormsg],addr) else: (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self) if answer.header.rcode == 0: # schedule NOTIFYs to slaves for ipaddr in slaves: self.notifys.append((origin, ipaddr, 0)) server.sendpackets([answer],addr) 
elif msg.header.opcode == 4: if msg.header.qr: log(0,'got NOTIFY response') for origin, ipaddr, trynum, senttime in self.sentnotify: if ipaddr == addr[0] and msg.question.qname == origin: self.sentnotify.remove((origin, ipaddr, trynum, senttime)) else: log(0,'got NOTIFY') self.process_notify(msg, addr[0], addr[1]) elif not msg.header.qr and msg.header.opcode == 0: # it's a question qname = msg.question.qname.lower() log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype + ') from ' + addr[0]) # handle special version packet if (msg.question.qtype == 'TXT' and msg.question.qclass == 'CH'): if qname == 'version.bind': server.sendpackets([getversion(qname, msg.header.id, msg.header.rd, dorecursion, '1.0')],addr) elif qname == 'version.oak': server.sendpackets([getversion(qname, msg.header.id, msg.header.rd, dorecursion, '1.0')],addr) return self.zdb.lookup(zkeys, msg, addr, server, dorecursion, flist, self.lookup_callback) def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist): if answerlist: server.sendpackets(self.config.outpackets(answerlist), addr) elif dorecursion: if msg.question.qtype in ['AXFR','IXFR']: if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR': if self.resolver: server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr) else: # won't forward zone transfers and # don't handle recursive zone transfers server.sendpackets([self.error(msg.header.id, msg.question.qname, msg.question.qtype, msg.question.qclass,2)],addr) else: self.resolver.handle_query(msg, addr, flist, server.sendpackets) def reap(self): log(4,'in nameserver reap') # do all maintenence (interval) stuff here if self.resolver: self.resolver.reap() self.notify() curtime = time.time() if curtime > (self.last_reap_time + self.maint_int): self.last_reap_time = curtime # do zone transfers here if slave server and haven't asked for soa for (zkey, origin, masterip) in self.zdb.getslaves(curtime): if not self.askedsoa.has_key(zkey): 
self.need_zonetransfer(zkey, origin, masterip) for zkey in self.askedsoa.keys(): if curtime > self.askedsoa[zkey]['senttime'] + self.soatimeout: if self.askedsoa[zkey]['trynum'] > 3: self.zdb.remove_zone(zkey) del self.askedsoa[zkey] else: masterip = self.askedsoa[zkey]['masterip'] origin = self.askedsoa[zkey]['origin'] trynum = self.askedsoa[zkey]['trynum'] del self.askedsoa[zkey] self.need_zonetransfer(zkey, origin, masterip, trynum) def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) class resolver(asyncore.dispatcher): def __init__(self, cache, port=0): asyncore.dispatcher.__init__(self) self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) self.bind(('',port)) self.cache = cache self.outqnum = 0 self.outq = {} self.holdq = {} self.holdtime = 10 self.holdqlength = 100 self.last_reap_time = time.time() self.maint_int = 10 self.timeout = 3 def getoutqkey(self): self.outqnum = self.outqnum + 1 if self.outqnum == 99999: self.outqnum = 1 return str(self.outqnum) def error(self, id, qname, querytype, queryclass, rcode): error = message() error.header.id = id error.header.rcode = rcode error.header.qr = 1 error.question.qname = qname error.question.qtype = querytype error.question.qclass = queryclass return error def qpacket(self, id, qname, querytype, queryclass): # create a question query = message() query.header.id = id query.header.rd = 0 query.question.qname = qname query.question.qtype = querytype query.question.qclass = queryclass return query def send_to(self, msglist, addr): for msg in msglist: data = msg.buildpkt() if len(data) > 512: # packet to big msg.header.tc = 1 msg.header.ancount = 0 msg.answerlist = [] msg.header.nscount = 0 msg.authlist = [] msg.header.arcount = 0 msg.addlist = [] self.socket.sendto(msg.buildpkt(), addr) else: self.socket.sendto(data, addr) def 
handle_read(self): try: while 1: msgdata, addr = self.socket.recvfrom(1500) # should put 'try' here in production server self.handle_packet(msgdata, addr) except socket.error, why: if why[0] != asyncore.EWOULDBLOCK: raise socket.error, why def handle_packet(self, msgdata, addr): try: msg = message(msgdata) except: return if not msg.header.qr: self.handle_query(msg, addr, [], self.send_to) else: log(2,'received unsolicited reply') def handle_query(self, msg, addr, flist, cbfunc): qname = msg.question.qname querytype = msg.question.qtype queryclass = msg.question.qclass # check the cache first answer = self.cache.haskey(qname,querytype,msg) if answer: cbfunc([answer], addr) log(2,'sent answer for ' + qname + '(' + querytype + ') from cache') else: # check if query is already in progess for oqkey in self.outq.keys(): if (self.outq[oqkey]['qname'] == qname and self.outq[oqkey]['querytype'] == querytype): log(2,'query already in progress for '+qname+'('+querytype+')') # put entry in hold queue to try later hqrec = {'processtime':time.time()+self.holdtime, 'query':msg,'addr':addr, 'qname':qname,'querytype':querytype, 'queryclass':queryclass, 'cbfunc':cbfunc} self.putonhold(hqrec) return outqkey = self.getoutqkey()+str(msg.header.id) self.outq[outqkey] = {'query':msg, 'addr':addr, 'qname':qname, 'querytype':querytype, 'queryclass':queryclass, 'cbfunc':cbfunc, 'answerlist':[], 'addlist':[], 'qsent':0} if flist: self.outq[outqkey]['flist'] = flist self.askfns(outqkey) else: self.askns(outqkey) def putonhold(self,hqrec): hqid = hqrec['qname']+hqrec['querytype'] if self.holdq.has_key(hqid): if len(self.holdq[hqid]) < self.holdqlength: hqrec['processtime']=time.time()+self.holdtime self.holdq[hqid].append(hqrec) def askns(self, outqkey): qname = self.outq[outqkey]['qname'] querytype = self.outq[outqkey]['querytype'] queryclass = self.outq[outqkey]['queryclass'] # don't try more than 10 times to avoid loops if self.outq[outqkey]['qsent'] == 10: del self.outq[outqkey] 
log(2,'Dropping query for ' + qname + '(' + querytype + ')' + ' POSSIBLE LOOP') return # find the best nameservers to ask from the cache (qzone, nsdict) = self.cache.getnslist(qname) if not nsdict: # there are no good servers if self.outq[outqkey]['addr'] != 'IQ': qid = self.outq[outqkey]['query'].header.id self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2), self.outq[outqkey]['addr']) del self.outq[outqkey] log(2,'Dropping query for ' + qname + '(' + querytype + ')' + 'no good name servers to ask') return # pick the best nameserver rtts = nsdict.keys() rtts.sort() bestnsip = nsdict[rtts[0]]['ip'] bestnsname = nsdict[rtts[0]]['name'] # fill in the callback data structure id=random.randrange(1,32768) self.outq[outqkey]['nsqueriedlastip'] = bestnsip self.outq[outqkey]['nsqueriedlastname'] = bestnsname self.outq[outqkey]['nsdict'] = nsdict self.outq[outqkey]['qzone'] = qzone self.outq[outqkey]['qsenttime'] = time.time() self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1 # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53)) self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass), self.handle_response, bestnsip, outqkey) # update rtt so that we ask a different server next time self.cache.updatertt(bestnsname,qzone,1) log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname + ') for ' + qname + '(' + querytype + ')') def askfns(self, outqkey): flist = self.outq[outqkey]['flist'] qname = self.outq[outqkey]['qname'] querytype = self.outq[outqkey]['querytype'] queryclass = self.outq[outqkey]['queryclass'] self.outq[outqkey]['qsenttime'] = time.time() id=random.randrange(1,32768) # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53)) self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass), self.handle_fresponse, flist[0], outqkey) log(2,''+outqkey+'|sent query to forwarder') def handle_response(self, msg, 
outqkey): # either reponse: # 1. contains a name error # 2. answers the question # (cache data and return it) # 3. is (contains) a CNAME and qtype isn't # (cache cname and change qname to it) # (check if qname and qtype are in any other rrs in the response) # (must check cache again here) # 4. contains a better delegation # (cache the delegation and start again) # 5. is aserver failure # (delete server from list and try again) # make sure that original question is still outstanding if not self.outq.has_key(outqkey): # should never get here # if we do we aren't doing housekeeping of callbacks very well log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname) return querytype = self.outq[outqkey]['querytype'] if msg.header.rcode not in [1,2,4,5]: # update rtt time rtt = time.time() - self.outq[outqkey]['qsenttime'] nsname = self.outq[outqkey]['nsqueriedlastname'] zone = self.outq[outqkey]['qzone'] self.cache.updatertt(nsname,zone,rtt) if msg.header.rcode == 3: log(2,outqkey+'|GOT Name Error for ' + msg.question.qname + '(' + msg.question.qtype + ')') # name error # cache negative answer self.cache.addneg(self.outq[outqkey]['qname'], self.outq[outqkey]['querytype'], self.outq[outqkey]['queryclass']) if self.outq[outqkey]['addr'] != 'IQ': answer = message() answer.question.qname = self.outq[outqkey]['query'].question.qname answer.question.qtype = self.outq[outqkey]['query'].question.qtype answer.question.qclass = self.outq[outqkey]['query'].question.qclass answer.header.id = self.outq[outqkey]['query'].header.id answer.header.qr = 1 answer.header.opcode = self.outq[outqkey]['query'].header.opcode answer.header.ra = 1 self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) del self.outq[outqkey] elif msg.header.ancount > 0: # answer (may be CNAME) haveanswer = 0 cname = '' log(2,'CACHING ANSWERLIST ENTRIES') for rr in msg.answerlist: rrname = rr.keys()[0] rrtype = rr[rrname].keys()[0] if ((rrname == msg.question.qname or 
rrname == cname ) and rrtype == msg.question.qtype): haveanswer = 1 if rrname == msg.question.qname and rrtype == 'CNAME': cname = rr[rrname][rrtype][0]['cname'] self.cache.add(rr, self.outq[outqkey]['qzone'], self.outq[outqkey]['nsqueriedlastname']) if haveanswer: if self.outq[outqkey]['addr'] != 'IQ': log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname + '(' + msg.question.qtype + ')' ) answer = message() answer.answerlist = msg.answerlist + self.outq[outqkey]['answerlist'] answer.header.ancount = len(answer.answerlist) answer.question.qname = self.outq[outqkey]['query'].question.qname answer.question.qtype = self.outq[outqkey]['query'].question.qtype answer.question.qclass = self.outq[outqkey]['query'].question.qclass answer.header.id = self.outq[outqkey]['query'].header.id answer.header.qr = 1 answer.header.opcode = self.outq[outqkey]['query'].header.opcode answer.header.ra = 1 self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) log(2,outqkey+'|sent answer retrieved from remote server for ' + self.outq[outqkey]['query'].question.qname) else: log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' + msg.question.qtype + ')') del self.outq[outqkey] elif cname: log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')') self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist self.outq[outqkey]['qname'] = cname self.askns(outqkey) else: log(2,outqkey+'|GOT BOGUS answer for ' + msg.question.qname + '(' + msg.question.qtype + ')') del self.outq[outqkey] elif msg.header.nscount > 0 and msg.header.ancount == 0: log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')') # delegation # cache the nameserver rrs and start over # if there are no glue records for nameservers must fetch them first log(2,'CACHING AUTHLIST ENTRIES') for rr in msg.authlist: self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) log(2,'CACHING 
ADDLIST ENTRIES') for rr in msg.addlist: self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) rrlist = msg.authlist+msg.addlist fetchglue = 0 nscount = 0 for rr in msg.authlist: nodename = rr.keys()[0] if rr[nodename].keys()[0] == 'NS': nscount = nscount + 1 nsdname = rr[nodename]['NS'][0]['nsdname'] if not self.cache.haskey(nsdname,'A'): log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)') fetchglue = fetchglue + 1 # need to fetch A rec noutqkey = self.getoutqkey()+str(random.randrange(1,32768)) self.outq[noutqkey] = {'query':'', 'addr':'IQ', 'qname':nsdname, 'querytype':'A', 'queryclass':'IN', 'qsent':0} log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)') self.askns(noutqkey) if not nscount: log(2,outqkey+'|Dropping query (no ns recs) for ' + msg.question.qname + '(' + msg.question.qtype + ')' ) del self.outq[outqkey] elif fetchglue == nscount: log(2,outqkey+'|Stalling query (no glue recs) for ' + msg.question.qname + '(' + msg.question.qtype + ')') self.putonhold(self.outq[outqkey]) del self.outq[outqkey] else: log(2,outqkey+'|got (some) glue with delegation') self.askns(outqkey) elif msg.header.rcode in [1,2,4,5]: log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode)) log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' + self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname) # don't ask this server for a while self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) self.askns(outqkey) else: log(2,outqkey+'|GOT UNPARSEABLE REPLY') msg.printpkt() def handle_fresponse(self, msg, outqkey): if msg.header.rcode in [1,2,4,5]: self.outq[outqkey]['flist'].pop(0) if len(self.outq[outqkey]['flist']) == 0: qid = self.outq[outqkey]['query'].header.id qname = self.outq[outqkey]['qname'] querytype = self.outq[outqkey]['querytype'] queryclass = self.outq[outqkey]['queryclass'] 
self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2), self.outq[outqkey]['addr']) del self.outq[outqkey] else: self.askfns(outqkey) else: answer = message() answer.header.id = self.outq[outqkey]['query'].header.id answer.header.qr = 1 answer.header.opcode = self.outq[outqkey]['query'].header.opcode answer.header.ra = 1 answer.question.qname = self.outq[outqkey]['query'].question.qname answer.question.qtype = self.outq[outqkey]['query'].question.qtype answer.question.qclass = self.outq[outqkey]['query'].question.qclass answer.header.ancount = msg.header.ancount answer.header.nscount = msg.header.nscount answer.header.arcount = msg.header.arcount answer.answerlist = msg.answerlist answer.authlist = msg.authlist answer.addlist = msg.addlist if msg.header.rcode == 3: # name error # cache negative answer self.cache.addneg(self.outq[outqkey]['qname'], self.outq[outqkey]['querytype'], self.outq[outqkey]['queryclass']) else: # cache all rrs for rr in msg.answerlist: self.cache.add(rr,'','forwarder') for rr in msg.authlist: self.cache.add(rr,'','forwarder') for rr in msg.addlist: self.cache.add(rr,'','forwarder') self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) del self.outq[outqkey] def writable(self): return 0 def handle_write(self): pass def handle_connect(self): pass def handle_close(self): # print '1:In handle close' return def process_holdq(self): curtime = time.time() for hqkey in self.holdq.keys(): for hqrec in self.holdq[hqkey]: if curtime >= hqrec['processtime']: log(2,'processing held query') answer = self.cache.haskey(hqrec['qname'], hqrec['querytype'], hqrec['query']) if answer: hqrec['cbfunc']([answer], hqrec['addr']) log(2,'sent answer for ' + hqrec['qname'] + '(' + hqrec['querytype'] + ') from cache') self.holdq[hqkey].remove(hqrec) if len(self.holdq[hqkey]) == 0: del self.holdq[hqkey] def reap(self): self.process_holdq() curtime = time.time() log(3,timestamp() + 'processed HOLDQ (sockets: ' + 
str(len(asyncore.socket_map.keys()))+')') if curtime > (self.last_reap_time + self.maint_int): self.last_reap_time = curtime for outqkey in self.outq.keys(): if curtime > self.outq[outqkey]['qsenttime'] + self.timeout: log(2,'query for '+self.outq[outqkey]['qname']+'('+ self.outq[outqkey]['querytype']+') expired') # don't set forwarders as bad if not self.outq[outqkey].has_key('flist'): self.cache.badns(self.outq[outqkey]['qzone'], self.outq[outqkey]['nsqueriedlastname']) if self.outq[outqkey].has_key('request'): log(3,'closing socket for expired query') self.outq[outqkey]['request'].close() del self.outq[outqkey] return def log_info (self, message, type='info'): if __debug__ or type != 'info': log(0,'%s: %s' % (type, message)) def run(configobj): global loglevel r = resolver(dnscache(configobj.cached)) ns = nameserver(r, configobj) udpds = udpdnsserver(53, ns) tcpds = tcpdnsserver(53, ns) loglevel = configobj.loglevel try: loop(ns.reap) except KeyboardInterrupt: print 'server done' if __name__ == '__main__': sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen') zonedict = {'example.net':{'origin':'example.net', 'filename':'db.example.net', 'type':'master', 'slaves':[]}} zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu', 'filename':'db.servers.csail.mit.edu', 'type':'master', 'slaves':[]}} zonedict2 = {'example.net':{'origin':'example.net', 'filename':'db.example.net', 'type':'slave', 'masterip':'127.0.0.1'}} readzonefiles(zonedict) lconfig = dnsconfig() lconfig.zonedatabase = zonedb(zonedict) pr = zonefileparser() pr.parse('','db.ca') lconfig.cached = pr.getzdict() lconfig.loglevel = 3 run(lconfig)