source: trunk/dns/nameserver.py @ 172

Last change on this file since 172 was 131, checked in by ecprice, 17 years ago

DNS (*.servers.csail.mit.edu from database)

  • Property svn:executable set to *
File size: 118.6 KB
Line 
1#!/usr/bin/python
2#    Python Domain Name Server
3#    Copyright (C) 2002  Digital Lumber, Inc.
4
5#    This library is free software; you can redistribute it and/or
6#    modify it under the terms of the GNU Lesser General Public
7#    License as published by the Free Software Foundation; either
8#    version 2.1 of the License, or (at your option) any later version.
9
10#    This library is distributed in the hope that it will be useful,
11#    but WITHOUT ANY WARRANTY; without even the implied warranty of
12#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13#    Lesser General Public License for more details.
14
15#    You should have received a copy of the GNU Lesser General Public
16#    License along with this library; if not, write to the Free Software
17#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18
19import socket
20import asyncore
21import asynchat
22import select
23import types
24import random
25import time
26import signal
27import string
28import sys
29import sipb_xen_database
30from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
31     ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT
32
33# EXAMPLE ZONE FILE DATA STRUCTURE
34
35# NOTE:
36# There are no trailing dots in the internal data
37# structure.  Although it's hard to tell by reading
38# the RFC's, the dots on the end of names are just
39# used internally by the resolvers and servers to
40# see if they need to append a domain name onto
41# the end of names.  There are no trailing dots
42# on names in queries on the network.
43
44examplenet = {'example.net':{'SOA':[{'class':'IN',
45                                     'ttl':10,
46                                     'mname':'ns1.example.net',
47                                     'rname':'hostmaster.example.net',
48                                     'serial':1,
49                                     'refresh':10800,
50                                     'retry':3600,
51                                     'expire':604800,
52                                     'minimum':3600}],
53                             'NS':[{'class':'IN',
54                                    'ttl':10,
55                                    'nsdname':'ns1.example.net'},
56                                   {'ttl':10,
57                                    'nsdname':'ns2.example.net'}],
58                             'MX':[{'class':'IN',
59                                    'ttl':10,
60                                    'preference':10,
61                                    'exchange':'mail.example.net'}]},
62              'server1.example.net':{'A':[{'class':'IN',
63                                           'ttl':10,
64                                           'address':'10.1.2.3'}]},
65              'www.example.net':{'CNAME':[{'class':'IN',
66                                           'ttl':10,
67                                           'cname':'server1.example.net'}]},
68              'router.example.net':{'A':[{'class':'IN',
69                                          'ttl':10,
70                                          'address':'10.1.2.1'},
71                                         {'class':'IN',
72                                          'ttl':10,
73                                          'address':'10.2.1.1'}]}
74             
75              }
76
77# setup logging defaults
78loglevel = 0
79logfile = sys.stdout
80
81try:
82    file
83except NameError:
84    def file(name, mode='r', buffer=0):
85        return open(name, mode, buffer)
86
87def log(level,msg):
88    if level <= loglevel:
89        logfile.write(msg+'\n')
90
91def timestamp():
92    return time.strftime('%m/%d/%y %H:%M:%S')+ '-'
93
94def inttoasc(number):
95    try:
96        hs = hex(number)[2:]
97    except:
98        log(0,'inttoasc cannot convert ' + repr(number))
99    if hs[-1:].upper() == 'L':
100        hs = hs[:-1]
101    result = ''
102    while len(hs) > 2:
103        result = chr(int(hs[-2:],16)) + result
104        hs = hs[:-2]
105    result = chr(int(hs,16)) + result
106   
107    return result
108   
109def asctoint(ascnum):
110    rascnum = ''
111    for i in range(len(ascnum)-1,-1,-1):
112        rascnum = rascnum + ascnum[i]
113    result = 0
114    count = 0
115    for c in rascnum:
116        x = ord(c) << (8*count)
117        result = result + x
118        count = count + 1
119
120    return result
121
122def ipv6net_aton(ip_string):
123    packed_ip = ''
124    # first account for shorthand syntax
125    pieces = ip_string.split(':')
126    pcount = 0
127    for part in pieces:
128        if part != '':
129            pcount = pcount + 1
130    if pcount < 8:
131        rs = '0:'*(8-pcount)
132        ip_string = ip_string.replace('::',':'+rs)
133    if ip_string[0] == ':':
134        ip_string = ip_string[1:]
135    pieces = ip_string.split(':')
136    for part in pieces:
137        # pad with the zeros
138        i = 4-len(part)
139        part = i*'0'+part
140        packed_ip = packed_ip +  chr(int(part[:2],16))+ chr(int(part[2:],16))
141    return packed_ip
142
143def ipv6net_ntoa(packed_ip):
144    ip_string = ''
145    count = 0
146    for c in packed_ip:
147        ip_string = ip_string + hex(ord(c))[2:]
148        count = count + 1
149        if count == 2:
150            ip_string = ip_string + ':'
151            count = 0
152    return ip_string[:-1]   
153
154def getversion(qname, id, rd, ra, versionstr):
155    msg = message()
156    msg.header.id = id
157    msg.header.qr = 1
158    msg.header.aa = 1
159    msg.header.rd = rd
160    msg.header.ra = ra
161    msg.header.rcode = 0
162    msg.question.qname = qname
163    msg.question.qtype = 'TXT'
164    msg.question.qclass = 'CH'
165    if qname == 'version.bind':
166        msg.header.ancount = 2
167        msg.answerlist.append({qname:{'CNAME':[{'cname':'version.oak',
168                                                'ttl':360000,
169                                                'class':'CH'}]}})
170        msg.answerlist.append({'version.oak':{'TXT':[{'txtdata':versionstr,
171                                                      'ttl':360000,
172                                                      'class':'CH'}]}})
173    else:
174        msg.header.ancount = 1
175        msg.answerlist.append({qname:{'TXT':[{'txtdata':versionstr,
176                                              'ttl':360000,
177                                              'class':'CH'}]}})
178    return msg
179
180def getrcode(rcode):
181    if rcode == 0:
182        rcodestr = 'NOERROR(No error condition)'
183    elif rcode == 1:
184        rcodestr = 'FORMERR(Format Error)'
185    elif rcode == 2:
186        rcodestr = 'SERVFAIL(Internal failure)'
187    elif rcode == 3:
188        rcodestr = 'NXDOMAIN(Name does not exist)'
189    elif rcode == 4:
190        rcodestr = 'NOTIMP(Not Implemented)'
191    elif rcode == 5:
192        rcodestr = 'REFUSED(Security violation)'
193    elif rcode == 6:
194        rcodestr = 'YXDOMAIN(Name exists)'
195    elif rcode == 7:
196        rcodestr = 'YXRRSET(RR exists)'
197    elif rcode == 8:
198        rcodestr = 'NXRRSET(RR does not exist)'
199    elif rcode == 9:
200        rcodestr = 'NOTAUTH(Server not Authoritative)'
201    elif rcode == 10:
202        rcodestr = 'NOTZONE(Name not in zone)'
203    else:
204        rcodestr = 'Unknown RCODE(' + str(rcode) + ')'
205    return rcodestr
206
207def printrdata(dnstype, rdata):
208    if dnstype == 'A':
209        return rdata['address']
210    elif dnstype == 'MX':
211        return str(rdata['preference'])+'\t'+rdata['exchange']+'.'
212    elif dnstype == 'NS':
213        return rdata['nsdname']+'.'
214    elif dnstype == 'PTR':
215        return rdata['ptrdname']+'.'
216    elif dnstype == 'CNAME':
217        return rdata['cname']+'.'
218    elif dnstype == 'SOA':
219        return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+
220                35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+
221                str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )')
222
223def makezonedatalist(zonedata, origin):
224    # unravel structure into list
225    zonedatalist = []
226    # get soa first
227    soanode = zonedata[origin]
228    zonedatalist.append([origin+'.','SOA',soanode['SOA'][0]])
229    for item in soanode.keys():
230        if item != 'SOA':
231            for listitem in soanode[item]:
232                zonedatalist.append([origin+'.', item, listitem])
233    for nodename in zonedata.keys():
234        if nodename != origin:
235            for item in zonedata[nodename].keys():
236                for listitem in zonedata[nodename][item]:
237                    zonedatalist.append([nodename+'.', item, listitem])
238    return zonedatalist
239
240def writezonefile(zonedata, origin, file):
241    zonedatalist = makezonedatalist(zonedata, origin)
242    for rr in zonedatalist:
243        owner = rr[0]
244        dnstype = rr[1]
245        line = (owner + (35-len(owner))*' ' + str(rr[2]['ttl']) + '\t\tIN\t' +
246                dnstype + '\t' + printrdata(dnstype, rr[2]))
247        file.write(line + '\n')
248
249def readzonefiles(zonedict):
250    for k in zonedict.keys():
251        filepath = zonedict[k]['filename']
252        try:
253            pr = zonefileparser()
254            pr.parse(zonedict[k]['origin'],filepath)
255            zonedict[k]['zonedata'] = pr.getzdict()
256        except ZonefileError, lineno:
257            log(0,'Error reading zone file ' + filepath  + ' at line ' +
258                str(lineno) + '\n')
259            del zonedict[k]
260
261def slowloop(tofunc='',timeout=5.0):
262    if not tofunc:
263        def tofunc(): return
264    map = asyncore.socket_map
265    while map:
266        r = []; w=[]; e=[]
267        for fd, obj in map.items():
268            if obj.readable():
269                r.append(fd)
270            if obj.writable():
271                w.append(fd)
272        try:
273            starttime = time.time()           
274            r,w,e = select.select(r,w,e,timeout)
275            endtime = time.time()
276            if endtime-starttime >= timeout:
277                tofunc()
278        except select.error, err:
279            if err[0] != EINTR:
280                raise
281            r=[]; w=[]; e=[]
282            log(0,'ERROR in select')
283
284        for fd in r:
285            try:
286                obj=map[fd]
287            except KeyError:
288                log(0,'KeyError in socket map')               
289                continue
290            try:
291                obj.handle_read_event()
292            except:
293                log(0,'calling HANDLE ERROR from loop')
294                log(0,repr(obj))
295                obj.handle_error()
296        for fd in w:
297            try:
298                obj=map[fd]
299            except KeyError:
300                log(0,'KeyError in socket map')               
301                continue
302            try:
303                obj.handle_read_event()
304            except:
305                log(0,'calling HANDLE ERROR from loop')
306                log(0,repr(obj))               
307                obj.handle_error()
308
309def fastloop(tofunc='',timeout=5.0):
310    if not tofunc:
311        def tofunc(): return
312    polltimeout = timeout*1000
313    map = asyncore.socket_map
314    while map:
315        regfds = 0
316        pollobj = select.poll()
317        for fd, obj in map.items():
318            flags = 0
319            if obj.readable():
320                flags = select.POLLIN
321            if obj.writable():
322                flags = flags | select.POLLOUT
323            if flags:
324                pollobj.register(fd, flags)
325                regfds = regfds + 1
326        try:
327            starttime = time.time()
328            r = pollobj.poll(polltimeout)
329            endtime = time.time()
330            if endtime-starttime >= timeout:
331                tofunc()
332        except select.error, err:
333            if err[0] != EINTR:
334                raise
335            r = []
336            log(0,'ERROR in select')
337        for fd, flags in r:
338            try:
339                obj = map[fd]
340                badvals = (select.POLLPRI + select.POLLERR +
341                           select.POLLHUP + select.POLLNVAL)
342                if (flags & badvals):
343                    if (flags & select.POLLPRI):
344                        log(0,'POLLPRI')
345                    if (flags & select.POLLERR):
346                        log(0,'POLLERR')
347                    if (flags & select.POLLHUP):
348                        log(0,'POLLHUP')
349                    if (flags & select.POLLNVAL):
350                        log(0,'POLLNVAL')
351                    obj.handle_error()
352                else:
353                    if (flags  & select.POLLIN):
354                        obj.handle_read_event()
355                    if (flags & select.POLLOUT):
356                        obj.handle_write_event()
357            except KeyError:
358                log(0,'KeyError in socket map')
359                continue
360            except:
361                # print traceback
362                sf = StringIO.StringIO()
363                traceback.print_exc(file=sf)
364                log(0,'ERROR IN LOOP:')
365                log(0,sf.getvalue())
366                sf.close()
367                log(0,repr(obj))
368                obj.handle_error()
369
370if hasattr(select,'poll'):
371    loop = fastloop
372else:
373    loop = slowloop
374
375class ZonefileError(Exception):
376    def __init__(self, linenum, errordesc=''):
377        self.linenum = linenum
378        self.errordesc = errordesc
379    def __str__(self):
380        return str(self.linenum) + ' (' + self.errordesc + ')'
381
382class zonefileparser:
383    def __init__(self):
384        self.zonedata = {}
385        self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX',
386                         'NS','PTR','RP','SOA','SRV','TXT']
387       
388    def stripcomments(self, line):
389        i = line.find(';')
390        if i >= 0:
391            line = line[:i]
392        return line
393
394    def strip(self, line):
395        # strip trailing linefeeds
396        if line[-1:] == '\n':
397            line = line[:-1]
398        return line
399
400    def getzdict(self):
401        return self.zonedata
402
403    def addorigin(self, origin, name):
404        if name[-1:] != '.':
405            return name + '.' + origin
406        else:
407            return name[:-1]
408
409    def getstrings(self, s):
410        if s.find('"') == -1:
411            return s.split()
412        else:
413            x = s.split('"')
414            rlist = []
415            for i in x:
416                if i != '' and i != ' ':
417                    rlist.append(i)
418            return rlist
419
420    def getlocsize(self, s):
421        if s[-1:] == 'm':
422            size = float(s[:-1])*100
423        else:
424            size = float(s)*100
425        i = 0
426        while size > 9:
427            size = size/10
428            i = i + 1
429        return (int(size),i)
430
431    def getloclat(self, l,c):
432        deg = float(l[0])
433        min = 0
434        secs = 0
435        if len(l) == 3:
436            min = float(l[1])
437            secs = float(l[2])
438        elif len(l) == 2:
439            min = float(l[1])
440        rval = ((((deg *60) + min) * 60) + secs) * 1000
441        if c in ['N','E']:
442            rval = rval + (2**31)
443        elif c in ['S','W']:
444            rval = (2**31) - rval
445        else:
446            log(0,'ERROR: unsupported latitude/longitude direction')
447        return long(rval)
448
449    def getgname(self, name, iter):
450        if name == '0' or name == 'O':
451            return ''
452        start = 0
453        offset = 0
454        width = 0
455        base = 'd'
456        for x in range(name.count('$')):
457            i = name.find('$',start)
458            j = i
459            start = i+1
460            if i>0:
461                if name[i-1] == '\\':
462                    continue
463            if len(name)>i+1:
464                if name[i+1] == '$':
465                    continue
466                if name[i+1] == '{':
467                    j = name.find('}',i+1)
468                    owb = name[i+2:j].split(',')
469                    if len(owb) == 1:
470                        offset = int(owb[0])
471                    elif len(owb) == 2:
472                        offset = int(owb[0])
473                        width = int(owb[1])
474                    elif len(owb) == 3:
475                        offset = int(owb[0])
476                        width = int(owb[1])
477                        base = owb[2]
478            val = iter - offset
479            if base == 'd':
480                rs = str(val)
481            elif base == 'o':
482                rs = oct(val)
483            elif base == 'x':
484                rs = hex(val)[2:].lower()
485            elif base == 'X':
486                rs = hex(val)[2:].upper()
487            else:
488                rs = ''
489            if len(rs) > width:
490                rs = (width-len(rs))*'0'+rs
491            name = name[:i]+rs+name[j+1:]
492            start = i+len(rs)+1
493
494        return name
495
496    def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens):
497        rdata = {}
498        rdata['class'] = dnsclass
499        rdata['ttl'] = ttl
500        if dnstype == 'A':
501            rdata['address'] = tokens[0]
502        elif dnstype == 'AAAA':
503            rdata['address'] = tokens[0]
504        elif dnstype == 'CNAME':
505            rdata['cname'] = self.addorigin(origin,tokens[0].lower())
506        elif dnstype == 'HINFO':
507            sl = self.getstrings(' '.join(tokens))
508            rdata['cpu'] = sl[0]
509            rdata['os'] = sl[1]
510        elif dnstype == 'LOC':
511            if 'N' in tokens:
512                i = tokens.index('N')
513            else:
514                i = tokens.index('S')
515            lat = self.getloclat(tokens[0:i],tokens[i])           
516            if 'E' in tokens:
517                j = tokens.index('E')
518            else:
519                j = tokens.index('W')
520            lng = self.getloclat(tokens[i+1:j],tokens[j])
521            size = self.getlocsize('1m')
522            horiz_pre = self.getlocsize('1000m')
523            vert_pre = self.getlocsize('10m')
524            if len(tokens[j+1:]) == 2:
525                size = self.getlocsize(tokens[-1:][0])
526            elif len(tokens[j+1:]) == 3:
527                size = self.getlocsize(tokens[-2:-1][0])
528                horiz_pre = self.getlocsize(tokens[-1:][0])
529            elif len(tokens[j+1:]) == 4:
530                size = self.getlocsize(tokens[-3:-2][0])
531                horiz_pre = self.getlocsize(tokens[-2:-1][0])
532                vert_pre = self.getlocsize(tokens[-1:][0])
533            if tokens[j+1][-1:] == 'm':
534                alt = int((float(tokens[j+1][:-1])*100)+10000000)
535            else:
536                size = int((float(tokens[j+1])*100)+10000000)
537            rdata['version'] = 0
538            rdata['size'] = size
539            rdata['horiz_pre'] = horiz_pre
540            rdata['vert_pre'] = vert_pre
541            rdata['latitude'] = lat
542            rdata['longitude'] = lng
543            rdata['altitude'] = 0
544        elif dnstype == 'MX':
545            rdata['preference'] = int(tokens[0])
546            rdata['exchange'] = self.addorigin(origin,tokens[1].lower())
547        elif dnstype == 'NS':
548            rdata['nsdname'] = self.addorigin(origin,tokens[0].lower())
549        elif dnstype == 'PTR':
550            rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower())
551        elif dnstype == 'RP':
552            rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower())
553            rdata['txtdname'] = self.addorigin(origin,tokens[1].lower())
554        elif dnstype == 'SOA':
555            rdata['mname'] = self.addorigin(origin,tokens[0].lower())
556            rdata['rname'] = self.addorigin(origin,tokens[1].lower())
557            rdata['serial'] = int(tokens[2])
558            rdata['refresh'] = int(tokens[3])
559            rdata['retry'] = int(tokens[4])
560            rdata['expire'] = int(tokens[5])
561            rdata['minimum'] = int(tokens[6])
562        elif dnstype == 'SRV':
563            rdata['priority'] = int(tokens[0])
564            rdata['weight'] = int(tokens[1])
565            rdata['port'] = int(tokens[2])
566            rdata['target'] = self.addorigin(origin,tokens[3].lower())
567        elif dnstype == 'TXT':
568            rdata['txtdata'] = self.getstrings(' '.join(tokens))[0]
569        else:
570            raise ZonefileError(lineno,'bad DNS type')           
571        return rdata
572
573    def addrec(self, owner, dnstype, rrdata):
574        if self.zonedata.has_key(owner):
575            if not self.zonedata[owner].has_key(dnstype):
576                self.zonedata[owner][dnstype] = []
577        else:
578            self.zonedata[owner] = {}
579            self.zonedata[owner][dnstype] = []
580        self.zonedata[owner][dnstype].append(rrdata)
581
582    def parse(self, origin, f):
583        closefile = 0
584        if type(f) != types.FileType:
585            # must be a path
586            try:
587                f = file(f)
588                closefile = 1
589            except:
590                log(0,'Invalid path to zonefile')
591                return
592        lastowner = ''
593        lastdnsclass = ''
594        lastttl = 3600
595        lineno = 0
596        while 1:
597            line = f.readline()
598            if not line:
599                break
600            lineno = lineno + 1
601            line = self.stripcomments(line)
602            line = self.strip(line)
603            if not line:
604                continue
605            if line.find('(') >= 0:
606                # grab lines until end paren
607                if line.find(')') == -1:
608                    line2 = self.stripcomments(f.readline())
609                    lineno = lineno + 1
610                    line2 = self.strip(line2)
611                    line = line + line2
612                    while line2.find(')') == -1:
613                        line2 = self.strip(self.stripcomments(f.readline()))
614                        lineno = lineno + 1
615                        line = line + line2
616                # now strip the parenthesis
617                line = line.replace(')','')
618                line = line.replace('(','')
619            # now line equals the entire RR entry
620            tokens = line.split()
621            if tokens[0].upper() == '$ORIGIN':
622                try:
623                    origin = tokens[1].lower()
624                except:
625                    raise ZonefileError(lineno, 'bad origin')
626            elif tokens[0].upper() == '$INCLUDE':
627                try:
628                    f2 = file(tokens[1].lower())
629                    if len(tokens) > 2:
630                        self.parse(tokens[2].lower(), f2)
631                    else:
632                        self.parse(origin, f2)
633                    f2.close()
634                except:
635                    raise ZonefileError(lineno, 'bad INCLUDE directive')
636            elif tokens[0].upper() == '$TTL':
637                try:
638                    lastttl = int(tokens[1])
639                except:
640                    raise ZonefileError(lineno, 'bad TTL directive')
641            elif tokens[0].upper() == '$GENERATE':
642                try:
643                    lhs = tokens[2].lower()
644                    dnstype = tokens[3].upper()
645                    rhs = tokens[4].lower()
646                    rng = tokens[1].split('-')                   
647                    start = int(rng[0])
648                    i = rng[1].find('/')
649                    if i != -1:
650                        stop = int(rng[1][:i])+1
651                        step = int(rng[1][i+1:])
652                    else:
653                        stop = int(rng[1])+1
654                        step = 1
655                    for i in range(start,stop,step):
656                        grhs = self.getgname(rhs,i)
657                        if dnstype in ['NS','CNAME','PTR']:
658                            grhs = self.addorigin(origin,grhs)
659                        rrdata = self.getrrdata(origin, dnstype, 'IN', lastttl,
660                                                [grhs])
661                        glhs = self.addorigin(origin,self.getgname(lhs,i))
662                        self.addrec(glhs,dnstype, rrdata)
663                except KeyError:
664                    raise ZonefileError(lineno, 'bad GENERATE directive')
665            else:
666                try:
667                    # if line begins with blank then owner is last owner
668                    if line[0] in string.whitespace:
669                        owner = lastowner
670                    else:
671                        owner = tokens[0].lower()
672                        tokens = tokens[1:]
673                        if owner == '@':
674                            owner = origin
675                        elif owner[-1:] != '.':
676                            owner = owner + '.' + origin
677                        else:
678                            owner = owner[:-1] # strip off trailing dot
679                    # line format is either: [class] [ttl] type RDATA
680                    #                     or [ttl] [class] type RDATA
681                    # - items in brackets are optional
682                    #
683                    # need to figure out which token is type
684                    # and backfill the missing data
685                    count = 0
686                    for token in tokens:
687                        if token.upper() in self.dnstypes:
688                            break
689                        count = count + 1
690                    # the following strips off the ttl and class if they exist
691                    if count == 0:
692                        ttl = lastttl
693                        dnsclass = lastdnsclass
694                    elif count == 1:
695                        if tokens[0].isdigit():
696                            ttl = int(tokens[0])
697                            dnsclass = lastdnsclass
698                        else:
699                            ttl = lastttl
700                            dnsclass = tokens[0].upper()
701                        tokens = tokens[1:]
702                    elif count == 2:
703                        if tokens[0].isdigit():
704                            ttl = int(tokens[0])
705                            dnsclass = tokens[1].upper()
706                        else:
707                            ttl = int(tokens[1])
708                            dnsclass = tokens[0].upper()
709                        tokens = tokens[2:]
710                    else:
711                        raise ZonefileError(lineno,'bad ttl or class')
712                    dnstype = tokens[0]
713                    # make sure all of the structure is there
714                    rrdata = self.getrrdata(origin, dnstype, dnsclass,
715                                            ttl, tokens[1:])
716                    self.addrec(owner, dnstype, rrdata)
717                    lastowner = owner
718                    lastttl = ttl
719                    lastdnsclass = dnsclass
720                except:
721                    raise ZonefileError(lineno,'unable to parse line')
722        if closefile:
723            f.close()
724       
725class dnsconfig:
726    def __init__(self):
727        # self.zonedb = zonedb({})
728        self.cached = {}
729        self.loglevel = 0
730       
731    def getview(self, msg, address, port):
732        # return:
733        #  1. a list of zone keys
734        #  2. whether or not to use the resolver
735        #     (i.e. answer recursive queries)
736        #  3. a list of forwarder addresses
737        return ['servers.csail.mit.edu'], 1, []
738
739    def allowupdate(self, msg, address, port):
740        # return 1 if updates are allowed
741        # NOTE: can only update the zones
742        #       returned by the getview func
743        return 1
744
745    def outpackets(self, packetlist):
746        # modify outgoing packets
747        return packetlist
748
749class dnsheader:
750    def __init__(self, id=1):
751        self.id = id # 16bit identifier generated by queryer
752        self.qr = 0 # one bit field specifying query(0) or response(1)
753        self.opcode = 0 # 4bit field specifying type of query
754        self.aa = 0 # authoritative answer
755        self.tc = 0 # message is not truncated
756        self.rd = 1 # recursion desired
757        self.ra = 0 # recursion available?
758        self.z = 0 # reserved for future use
759        self.rcode = 0 # response code (set in response)
760        self.qdcount = 1 # number of questions, only 1 is supported
761        self.ancount = 0 # number of rrs in the answer section
762        self.nscount = 0 # number of name server rrs in authority section
763        self.arcount = 0 # number or rrs in the additional section
764
765class dnsquestion:
766    def __init__(self):
767        self.qname = 'localhost'
768        self.qtype = 'A'
769        self.qclass = 'IN'
770
771class dnsupdatezone:
772    pass
773
774class message:
775    def __init__(self, msgdata=''):
776        if msgdata:
777            self.header = dnsheader()
778        else:
779            self.header = dnsheader(id=random.randrange(1,32768))
780        self.question = dnsquestion()
781        self.answerlist = []
782        self.authlist = []
783        self.addlist = []
784        self.u = ''
785        self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA',
786                       7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS',
787                       12:'PTR',13:'HINFO',14:'MINFO',15:'MX',
788                       16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV',
789                       38:'A6',39:'DNAME',251:'IXFR',252:'AXFR',
790                       253:'MAILB',254:'MAILA',255:'ANY'}
791        self.rqtypes = {}
792        for key in self.qtypes.keys():
793            self.rqtypes[self.qtypes[key]] = key
794        self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'}
795        self.rqclasses = {}
796        for key in self.qclasses.keys():
797            self.rqclasses[self.qclasses[key]] = key
798
799        if msgdata:
800            self.processpkt(msgdata)
801
802    def getdomainname(self, data, i):
803        log(4,'IN GETDOMAINNAME')
804        domainname = ''
805        gotpointer = 0
806        labellength= ord(data[i])
807        log(4,'labellength:' + str(labellength))
808        i = i + 1
809        while labellength != 0:
810            while labellength >= 192:
811                # pointer
812                if not gotpointer:
813                    rindex = i + 1
814                    gotpointer = 1
815                    log(4,'got pointer')
816                i = asctoint(chr(ord(data[i-1]) & 63)+data[i])
817                log(4,'new index:'+str(i))
818                labellength = ord(data[i])
819                log(4,'labellength:' + str(labellength))
820                i = i + 1
821            if domainname:
822                domainname = domainname + '.' + data[i:i+labellength]
823            else:
824                domainname = data[i:i+labellength]
825            log(4,'domainname:'+domainname)
826            i = i + labellength
827            labellength = ord(data[i])
828            log(4,'labellength:' + str(labellength))
829            i = i + 1
830        if not gotpointer:
831            rindex = i
832
833        return domainname.lower(), rindex
834
835    def getrrdata(self, type, msgdata, rdlength, i):
836        log(4,'unpacking RR data')
837        rdata = msgdata[i:i+rdlength]
838        if type == 'A':
839            return {'address':socket.inet_ntoa(rdata)}
840        elif type == 'AAAA':
841            return {'address':ipv6net_ntoa(rdata)}
842        elif type == 'CNAME':
843            cname, i = self.getdomainname(msgdata,i)
844            return {'cname':cname}
845        elif type == 'HINFO':
846            cpulen = ord(rdata[0])
847            cpu = rdata[1:cpulen+1]
848            return {'cpu':cpu,
849                    'os':rdata[cpulen+2:]}
850        elif type == 'LOC':
851            return {'version':ord(rdata[0]),
852                    'size':self.locsize(rdata[1]),
853                    'horiz_pre':self.locsize(rdata[2]),
854                    'vert_pre':self.locsize(rdata[3]),
855                    'latitude':asctoint(rdata[4:8]),
856                    'longitude':asctoint(rdata[8:12]),
857                    'altitude':asctoint(rdata[12:16])}
858        elif type == 'MX':
859            exchange, i = self.getdomainname(msgdata,i+2)
860            return {'preference':asctoint(rdata[:2]),
861                    'exchange':exchange}
862        elif type == 'NS':
863            nsdname, i = self.getdomainname(msgdata,i)
864            return {'nsdname':nsdname}
865        elif type == 'PTR':
866            ptrdname, i = self.getdomainname(msgdata,i)
867            return {'ptrdname':ptrdname}
868        elif type == 'RP':
869            mboxdname, i = self.getdomainname(msgdata,i)
870            txtdname, i = self.getdomainname(msgdata,i)
871            return {'mboxdname':mboxdname,
872                    'txtdname':txtdname}
873        elif type == 'SOA':
874            mname, i = self.getdomainname(msgdata,i)
875            rname, i = self.getdomainname(msgdata,i)
876            return {'mname':mname,
877                    'rname':rname,
878                    'serial':asctoint(msgdata[i:i+4]),
879                    'refresh':asctoint(msgdata[i+4:i+8]),
880                    'retry':asctoint(msgdata[i+8:i+12]),
881                    'expire':asctoint(msgdata[i+12:i+16]),
882                    'minimum':asctoint(msgdata[i+16:i+20])}
883        elif type == 'SRV':
884            target, i = self.getdomainname(msgdata,i+6)           
885            return {'priority':asctoint(rdata[0:2]),
886                    'weight':asctoint(rdata[2:4]),
887                    'port':asctoint(rdata[4:6]),
888                    'target':target}
889        elif type == 'TXT':
890            return {'txtdata':rdata[1:]}
891        else:
892            return {'rdata':rdata}
893       
894    def getrr(self, data, i):
895        log(4,'unpacking RR name')
896        name, i = self.getdomainname(data, i)
897        type = asctoint(data[i:i+2])
898        type = self.qtypes.get(type,chr(type))
899        klass = asctoint(data[i+2:i+4])
900        klass = self.qclasses.get(klass,chr(klass))
901        ttl = asctoint(data[i+4:i+8])
902        rdlength = asctoint(data[i+8:i+10])
903        rrdata = self.getrrdata(type,data,rdlength,i+10)
904        rrdata['ttl'] = ttl
905        rrdata['class'] = klass
906        rr = {name:{type:[rrdata]}}
907        return rr, i+10+rdlength
908
909    def processpkt(self, msgdata):
910        self.header.id = asctoint(msgdata[:2])
911        self.header.qr = ord(msgdata[2]) >> 7
912        self.header.opcode = (ord(msgdata[2]) & 127) >> 3
913        if self.header.opcode == 5:
914            # UPDATE packet
915            log(4,'processing UPDATE packet')
916            del self.header.aa
917            del self.header.tc
918            del self.header.rd
919            del self.header.ra
920            del self.header.qdcount
921            del self.header.ancount
922            del self.header.nscount
923            del self.header.arcount
924            del self.question
925            self.zone = dnsupdatezone()
926            del self.answerlist
927            del self.authlist
928            del self.addlist
929            self.header.z = 0
930            self.header.rcode = ord(msgdata[3]) & 15
931            self.header.zocount = asctoint(msgdata[4:6])
932            self.header.prcount = asctoint(msgdata[6:8])
933            self.header.upcount = asctoint(msgdata[8:10])
934            self.header.arcount = asctoint(msgdata[10:12])
935            self.zolist = []
936            self.prlist = []
937            self.uplist = []
938            self.addlist = []
939            i = 12
940            for x in range(self.header.zocount):
941                (dn, i) = self.getdomainname(msgdata,i)
942                self.zone.zname = dn
943                type = asctoint(msgdata[i:i+2])
944                self.zone.ztype = self.qtypes.get(type,chr(type))
945                klass = asctoint(msgdata[i+2:i+4])
946                self.zone.zclass = self.qclasses.get(klass,chr(klass))
947                i = i + 4
948            for x in range(self.header.prcount):
949                rr, i  = self.getrr(msgdata,i)
950                self.prlist.append(rr)
951            for x in range(self.header.upcount):
952                rr, i  = self.getrr(msgdata,i)
953                self.uplist.append(rr)
954            for x in range(self.header.arcount):
955                rr, i  = self.getrr(msgdata,i)
956                self.adlist.append(rr)
957        else:
958            self.header.aa = (ord(msgdata[2]) & 4) >> 2
959            self.header.tc = (ord(msgdata[2]) & 2) >> 1
960            self.header.rd = ord(msgdata[2]) & 1
961            self.header.ra = ord(msgdata[3]) >> 7
962            self.header.z = (ord(msgdata[3]) & 112) >> 4
963            self.header.rcode = ord(msgdata[3]) & 15
964            self.header.qdcount = asctoint(msgdata[4:6])
965            self.header.ancount = asctoint(msgdata[6:8])
966            self.header.nscount = asctoint(msgdata[8:10])
967            self.header.arcount = asctoint(msgdata[10:12])
968            i = 12
969            for x in range(self.header.qdcount):
970                log(4,'unpacking question')
971                (dn, i) = self.getdomainname(msgdata,i)
972                self.question.qname = dn
973                rrtype = asctoint(msgdata[i:i+2])
974                self.question.qtype = self.qtypes.get(rrtype,chr(rrtype))
975                klass = asctoint(msgdata[i+2:i+4])
976                self.question.qclass = self.qclasses.get(klass,chr(klass))
977                i = i + 4
978            for x in range(self.header.ancount):
979                log(4,'unpacking answer RR')
980                rr, i = self.getrr(msgdata,i)
981                self.answerlist.append(rr)
982            for x in range(self.header.nscount):
983                log(4,'unpacking auth RR')
984                rr, i = self.getrr(msgdata,i)           
985                self.authlist.append(rr)
986            for x in range(self.header.arcount):
987                log(4,'unpacking additional RR')
988                rr, i = self.getrr(msgdata,i)           
989                self.addlist.append(rr)
990        return
991
992    def pds(self, s, l):
993        # pad string with chr(0)'s so that
994        # return string length is l
995        x = l - len(s)
996        return x*chr(0) + s
997
998    def locsize(self, s):
999        x1 = ord(s) >> 4
1000        x2 = ord(s) & 15
1001        return (x1, x2)
1002
1003    def packlocsize(self, x):
1004        return chr((x[0] << 4) + x[1])
1005
1006    def packdomainname(self, name, i, msgcomp):
1007        log(4,'packing domainname: ' + name)
1008        if name == '':
1009            return chr(0)
1010        if name in msgcomp.keys():
1011            log(4,'using pointer for: ' + name)
1012            return msgcomp[name]
1013        packedname = ''
1014        tokens = name.split('.')
1015        for j in range(len(tokens)):
1016            packedname = packedname + chr(len(tokens[j])) + tokens[j]
1017            nameleft = '.'.join(tokens[j+1:])
1018            if nameleft in msgcomp.keys():
1019                log(4,'using pointer for: ' + nameleft)
1020                return packedname+msgcomp[nameleft]
1021        # haven't used a pointer so put this in the dictionary
1022        pointer = inttoasc(i)
1023        if len(pointer) == 1:
1024            msgcomp[name] = chr(192)+pointer
1025        else:
1026            msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1]
1027        log(4,'added pointer for ' + name + '(' + str(i) + ')')
1028        return packedname + chr(0)
1029
1030    def packrr(self, rr, i, msgcomp):
1031        rrname = rr.keys()[0]
1032        rrtype = rr[rrname].keys()[0]
1033        if self.rqtypes.has_key(rrtype):
1034            typeval = self.rqtypes[rrtype]
1035        else:
1036            typeval = ord(rrtype)
1037        dbrec = rr[rrname][rrtype][0]
1038        ttl = dbrec['ttl']
1039        rclass = self.rqclasses[dbrec['class']]
1040        packedrr = (self.packdomainname(rrname, i, msgcomp) +
1041                    self.pds(inttoasc(typeval),2) +
1042                    self.pds(inttoasc(rclass),2) +
1043                    self.pds(inttoasc(ttl),4))
1044        i = i + len(packedrr) + 2
1045        if rrtype == 'A':
1046            rdata = socket.inet_aton(dbrec['address'])
1047        elif rrtype == 'AAAA':
1048            rdata = ipv6net_aton(dbrec['address'])
1049        elif rrtype == 'CNAME':
1050            rdata = self.packdomainname(dbrec['cname'], i, msgcomp)
1051        elif rrtype == 'HINFO':
1052            rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] +
1053                     chr(len(dbrec['os'])) + dbrec['os'])
1054        elif rrtype == 'LOC':
1055            rdata = (chr(dbrec['version']) +
1056                     self.packlocsize(dbrec['size']) +
1057                     self.packlocsize(dbrec['horiz_pre']) +
1058                     self.packlocsize(dbrec['vert_pre']) +
1059                     self.pds(inttoasc(dbrec['latitude']),4) +
1060                     self.pds(inttoasc(dbrec['longitude']),4) +
1061                     self.pds(inttoasc(dbrec['altitude']),4))
1062        elif rrtype == 'MX':
1063            rdata = (self.pds(inttoasc(dbrec['preference']),2) +
1064                     self.packdomainname(dbrec['exchange'], i+2, msgcomp))
1065        elif rrtype == 'NS':
1066            rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp)
1067        elif rrtype == 'PTR':
1068            rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp)
1069        elif rrtype == 'RP':
1070            rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
1071            i = i + len(rdata1)
1072            rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
1073            rdata = rdata1 + rdata2
1074        elif rrtype == 'SOA':
1075            rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp)
1076            i = i + len(rdata1)
1077            rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp)
1078            rdata = (rdata1 +
1079                     rdata2 +
1080                     self.pds(inttoasc(dbrec['serial']),4) +
1081                     self.pds(inttoasc(dbrec['refresh']),4) +
1082                     self.pds(inttoasc(dbrec['retry']),4) +
1083                     self.pds(inttoasc(dbrec['expire']),4) +
1084                     self.pds(inttoasc(dbrec['minimum']),4))
1085        elif rrtype == 'SRV':
1086            rdata = (self.pds(inttoasc(dbrec['priority']),2) +
1087                     self.pds(inttoasc(dbrec['weight']),2) +
1088                     self.pds(inttoasc(dbrec['port']),2) +
1089                     self.packdomainname(dbrec['target'], i+6, msgcomp))
1090        elif rrtype == 'TXT':
1091            rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata']
1092        else:
1093            rdata = dbrec['rdata']
1094
1095        return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata
1096
1097    def buildpkt(self):
1098        # keep dictionary of names packed (so we can use pointers)
1099        msgcomp = {}
1100        # header
1101        if self.header.id > 65535:
1102            log(0,'building packet with bad ID field')
1103            self.header.id = 1
1104        msgdata = inttoasc(self.header.id)
1105        if len(msgdata) == 1:
1106            msgdata = chr(0) + msgdata
1107        h1 = ((self.header.qr << 7) +
1108              (self.header.opcode << 3) +
1109              (self.header.aa << 2) +
1110              (self.header.tc << 1) +
1111              (self.header.rd))
1112        h2 = ((self.header.ra << 7) +
1113              (self.header.z << 4) +
1114              (self.header.rcode))
1115        msgdata = msgdata + chr(h1) + chr(h2)
1116        msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2)
1117        msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2)
1118        msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2)
1119        msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2)
1120        # question
1121        msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp)
1122        if self.rqtypes.has_key(self.question.qtype):
1123            typeval = self.rqtypes[self.question.qtype]
1124        else:
1125            typeval = ord(self.question.qtype)
1126        msgdata = msgdata + self.pds(inttoasc(typeval),2)
1127        if self.rqclasses.has_key(self.question.qclass):
1128            classval = self.rqclasses[self.question.qclass]
1129        else:
1130            classval = ord(self.question.qclass)
1131        msgdata = msgdata + self.pds(inttoasc(classval),2)
1132        # rr's
1133        # RR record format:
1134        # {'name' : {'type' : [rdata, rdata, ...]}
1135        # example: {'test.blah.net': {'A': [{'address': '10.1.1.2',
1136        #                                    'ttl': 3600L}]}}
1137        for rr in self.answerlist:
1138            log(4,'packing answer RR')
1139            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1140        for rr in self.authlist:
1141            log(4,'packing auth RR')
1142            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1143        for rr in self.addlist:
1144            log(4,'packing additional RR')
1145            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1146           
1147        return msgdata
1148
1149    def printpkt(self):
1150        print 'ID: ' +str(self.header.id)
1151        if self.header.qr:
1152            print 'QR: RESPONSE'
1153        else:
1154            print 'QR: QUERY'
1155        if self.header.opcode == 0:
1156            print 'OPCODE: STANDARD QUERY'
1157        elif self.header.opcode == 1:
1158            print 'OPCODE: INVERSE QUERY'
1159        elif self.header.opcode == 2:
1160            print 'OPCODE: SERVER STATUS REQUEST'
1161        elif self.header.opcode == 5:
1162            print 'UPDATE REQUEST'
1163        else:
1164            print 'OPCODE: UNKNOWN QUERY TYPE'
1165        if self.header.opcode != 5:
1166            if self.header.aa:
1167                print 'AA: AUTHORITATIVE ANSWER'
1168            else:
1169                print 'AA: NON-AUTHORITATIVE ANSWER'
1170            if self.header.tc:
1171                print 'TC: MESSAGE IS TRUNCATED'
1172            else:
1173                print 'TC: MESSAGE IS NOT TRUNCATED'
1174            if self.header.rd:
1175                print 'RD: RECURSION DESIRED'
1176            else:
1177                print 'RD: RECURSION NOT DESIRED'
1178            if self.header.ra:
1179                print 'RA: RECURSION AVAILABLE'
1180            else:
1181                print 'RA: RECURSION IS NOT AVAILABLE'
1182        if self.header.rcode == 1:
1183            printrcode =  'FORMERR'
1184        elif self.header.rcode == 2:
1185            printrcode =  'SERVFAIL'
1186        elif self.header.rcode == 3:
1187            printrcode =  'NXDOMAIN'
1188        elif self.header.rcode == 4:
1189            printrcode =  'NOTIMP'
1190        elif self.header.rcode == 5:
1191            printrcode =  'REFUSED'
1192        elif self.header.rcode == 6:
1193            printrcode = 'YXDOMAIN'
1194        elif self.header.rcode == 7:
1195            printrcode = 'YXRRSET'
1196        elif self.header.rcode == 8:
1197            printrcode = 'NXRRSET'
1198        elif self.header.rcode == 9:
1199            printrcode = 'NOTAUTH'
1200        elif self.header.rcode == 10:
1201            printrcode = 'NOTZONE'
1202        else:
1203            printrcode =  'NOERROR'
1204        print 'RCODE: ' + printrcode
1205        if self.header.opcode == 5:
1206            print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount)
1207            print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount)
1208            print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount)
1209            print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount)
1210            print 'ZONE SECTION:'
1211            print 'zname: ' + self.zone.zname
1212            print 'zonetype: ' + self.zone.ztype
1213            print 'zoneclass: ' + self.zone.zclass
1214            print 'PREREQUISITE RRs:'
1215            for rr in self.prlist:
1216                print rr
1217            print 'UPDATE RRs:'       
1218            for rr in self.uplist:
1219                print rr
1220            print 'ADDITIONAL RRs:'       
1221            for rr in self.addlist:
1222                print rr
1223
1224
1225        else:
1226            print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount)
1227            print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount)
1228            print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount)
1229            print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount)
1230            print 'QUESTION SECTION:'
1231            print 'qname: ' + self.question.qname
1232            print 'querytype: ' + self.question.qtype
1233            print 'queryclass: ' + self.question.qclass
1234            print 'ANSWER RRs:'
1235            for rr in self.answerlist:
1236                print rr
1237            print 'AUTHORITY RRs:'       
1238            for rr in self.authlist:
1239                print rr
1240            print 'ADDITIONAL RRs:'       
1241            for rr in self.addlist:
1242                print rr
1243
1244class zonedb:
1245    def __init__(self, zdict):
1246        self.zdict = zdict
1247        self.updates = {}
1248        for k in self.zdict.keys():
1249            if self.zdict[k]['type'] == 'slave':
1250                self.zdict[k]['lastupdatetime'] = 0
1251
1252    def error(self, id, qname, querytype, queryclass, rcode):
1253        error = message()
1254        error.header.id = id
1255        error.header.rcode = rcode
1256        error.header.qr = 1
1257        error.question.qname = qname
1258        error.question.qtype = querytype
1259        error.question.qclass = queryclass
1260        return error
1261
1262    def getorigin(self, zkey):
1263        origin = ''
1264        if self.zdict.has_key(zkey):
1265            origin = self.zdict[zkey]['origin']
1266        return origin
1267
1268    def getmasterip(self, zkey):
1269        masterip = ''
1270        if self.zdict.has_key(zkey):
1271            if self.zdict[zkey].has_key('masterip'):
1272                masterip = self.zdict[zkey]['masterip']
1273        return masterip
1274
1275    def zonetrans(self, query):
1276        # build a list of messages
1277        # each message contains one rr of the zone
1278        # the first and last message are the
1279        # SOA records
1280        origin = query.question.qname
1281        querytype = query.question.qtype
1282        zkey = ''
1283        for zonekey in self.zdict.keys():
1284            if self.zdict[zonekey]['origin'] == query.question.qname:
1285                zkey = zonekey
1286        if not zkey:
1287            return []
1288        zonedata = self.zdict[zkey]['zonedata']
1289        queryid = query.header.id
1290        soarec = zonedata[origin]['SOA'][0]
1291        soa = {origin:{'SOA':[soarec]}}
1292        curserial = soarec['serial']
1293        rrlist = []
1294        if querytype == 'IXFR':
1295            clientserial = query.authlist[0][origin]['SOA'][0]['serial']
1296            if clientserial < curserial:
1297                for i in range(clientserial,curserial+1):
1298                    if self.updates[zkey].has_key(i):
1299                        for rr in self.updates[zkey][i]['added']:
1300                            rrlist.append(rr)
1301                        for rr in self.updates[zkey][i]['removed']:
1302                            rrlist.append(rr)
1303                if len(rrlist) > 0:
1304                    rrlist.insert(0,soa)
1305                rrlist.append(soa)
1306            else:
1307                rrlist.append(soa)
1308        else:
1309            for nodename in zonedata.keys():
1310                for rrtype in zonedata[nodename].keys():
1311                    if not (rrtype == 'SOA' and nodename == origin):
1312                        for rr in zonedata[nodename][rrtype]:
1313                            rrlist.append({nodename:{rrtype:[rr]}})
1314            rrlist.insert(0,soa)
1315            rrlist.append(soa)
1316        msglist = []
1317        for rr in rrlist:
1318            msg = message()
1319            msg.header.id = queryid
1320            msg.header.qr = 1
1321            msg.header.aa = 1
1322            msg.header.rd = 0
1323            msg.header.qdcount = 1
1324            msg.question.qname = origin
1325            msg.question.qtype = querytype
1326            msg.question.qclass = 'IN'
1327            msg.header.ancount = 1
1328            msg.answerlist.append(rr)
1329            msglist.append(msg)
1330        return msglist
1331
1332    def update_zone(self, rrlist, params):
1333        zonekey = params[0]
1334        zonedata = {}
1335        soa = rrlist.pop()
1336        origin = soa.keys()[0]
1337        for rr in rrlist:
1338            rrname = rr.keys()[0]
1339            rrtype = rr[rrname].keys()[0]
1340            dbrec = rr[rrname][rrtype][0]
1341            if zonedata.has_key(rrname):
1342                if not zonedata[rrname].has_key(rrtype):
1343                    zonedata[rrname][rrtype] = []
1344            else:
1345                zonedata[rrname] = {}
1346                zonedata[rrname][rrtype] = []
1347            zonedata[rrname][rrtype].append(dbrec)
1348        self.zdict[zonekey]['zonedata'] = zonedata
1349        curtime = time.time()
1350        self.zdict[zonekey]['lastupdatetime'] = curtime
1351        try:
1352            f = file(self.zdict[zonekey]['filename'],'w')
1353            writezonefile(zonedata, self.zdict[zonekey]['origin'], f)
1354            f.close()
1355        except:
1356            log(0,'unable to write zone ' + zonekey + 'to disk')
1357        log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')')
1358
1359    def remove_zone(self, zonekey):
1360        if self.zdict.has_key(zonekey):
1361            del self.zdict[zonekey]
1362
1363    def getslaves(self, curtime):
1364        rlist = []
1365        for k in self.zdict.keys():
1366            if self.zdict[k]['type'] == 'slave':
1367                origin = self.zdict[k]['origin']
1368                refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh']
1369                if self.zdict[k]['lastupdatetime'] + refresh < curtime:
1370                    rlist.append((k, origin, self.zdict[k]['masterip']))
1371        return rlist
1372
1373    def zmatch(self, qname, zkeys):
1374        for zkey in zkeys:
1375            if self.zdict.has_key(zkey):
1376                origin = self.zdict[zkey]['origin']
1377                if qname.rfind(origin) != -1:
1378                    return zkey
1379        return ''
1380
1381    def getzlist(self, name, zone):
1382        if name == zone:
1383            return
1384        zlist = []
1385        i = name.rfind(zone)
1386        if i == -1:
1387            return
1388        firstpart = name[:i-1]
1389        partlist = firstpart.split('.')
1390        partlist.reverse()
1391        lastpart = zone
1392        for x in range(len(partlist)):
1393            lastpart = partlist[x] + '.' + lastpart
1394            zlist.append(lastpart)
1395        return zlist
1396
1397    def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc):
1398        # handle zone transfers seperately
1399        qname = query.question.qname
1400        querytype = query.question.qtype
1401        queryclass = query.question.qclass
1402        if querytype in ['AXFR','IXFR']:
1403            for zkey in self.zdict.keys():
1404                if zkey in zkeys:
1405                    if qname == self.zdict[zkey]['origin']:
1406                        answerlist = self.zonetrans(query)
1407                        break
1408            else:
1409                answerlist = []
1410            cbfunc(query, addr, server, dorecursion, flist, answerlist)
1411        else:
1412            zonekey = self.zmatch(qname, zkeys)
1413            if zonekey:
1414                origin = self.zdict[zonekey]['origin']
1415                zonedict = self.zdict[zonekey]['zonedata']
1416                referral = 0
1417                rranswerlist = []
1418                rrnslist = []
1419                rraddlist = []
1420                answer = message()
1421                answer.header.aa = 1
1422                answer.header.id = query.header.id
1423                answer.header.qr = 1
1424                answer.header.opcode = query.header.opcode
1425                answer.header.rcode = 4
1426                answer.header.ra = dorecursion
1427                answer.question.qname = query.question.qname
1428                answer.question.qtype = query.question.qtype
1429                answer.question.qclass = query.question.qclass
1430                answer.header.ra = dorecursion
1431                s = '.servers.csail.mit.edu'
1432                if qname.endswith(s):
1433                    host = qname[:-len(s)]
1434                    value = sipb_xen_database.NIC.get_by(hostname=host)
1435                    if value is None:
1436                        pass
1437                    else:
1438                        ip = value.ip
1439                        rranswerlist.append({qname: {'A': [{'address': ip, 
1440                                                            'class': 'IN', 
1441                                                            'ttl': 10}]}})
1442                if zonedict.has_key(qname):
1443                    # found the node, now take care of CNAMEs
1444                    if zonedict[qname].has_key('CNAME'):
1445                        if querytype != 'CNAME':
1446                            nodetype = 'CNAME'
1447                            while nodetype == 'CNAME':
1448                                rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}})
1449                                qname = zonedict[qname]['CNAME'][0]['cname']
1450                                if zonedict.has_key(qname):
1451                                    nodetype = zonedict[qname].keys()[0]
1452                                else:
1453                                    # error, shouldn't have a CNAME that points to nothing
1454                                    return
1455                    # if we get this far, then the record has matched and we should return
1456                    # a reply that has no error (even if there is no info macthing the qtype)
1457                    answer.header.rcode = 0
1458                    answernode = zonedict[qname]
1459                    if querytype == 'ANY':
1460                        for type in answernode.keys():
1461                            for rec in answernode[type]:
1462                                rranswerlist.append({qname:{type:[rec]}})
1463                    elif answernode.has_key(querytype):
1464                        for rec in answernode[querytype]:
1465                            rranswerlist.append({qname:{querytype:[rec]}})
1466                        # do rrset ordering (cyclic)
1467                        if len(answernode[querytype]) > 1:
1468                            rec = answernode[querytype].pop(0)
1469                            answernode[querytype].append(rec)
1470                    else:
1471                        # remove all cname rrs from answerlist
1472                        rranswerlist = []
1473                else:
1474                    # would check for wildcards here (but aren't because they seem bad)
1475                    # see if we need to give a referral
1476                    zlist = self.getzlist(qname,origin)
1477                    for zonename in zlist:
1478                        if zonedict.has_key(zonename):
1479                            if zonedict[zonename].has_key('NS'):
1480                                answer.header.rcode = 0
1481                                referral = 1
1482                                for rec in zonedict[zonename]['NS']:
1483                                    rrnslist.append({zonename:{'NS':[rec]}})
1484                                    nsdname = rec['nsdname']
1485                                    # add glue records if they exist
1486                                    if zonedict.has_key(nsdname):
1487                                        if zonedict[nsdname].has_key('A'):
1488                                            for gluerec in zonedict[nsdname]['A']:
1489                                                rraddlist.append({nsdname:{'A':[gluerec]}})
1490                    # negative caching stuff
1491                    if not referral:
1492                        if not rranswerlist:
1493                            # NOTE: RFC1034 section 4.3.4 says we should add the SOA record
1494                            #       to the additional section of the response.  BIND adds
1495                            #       it to the ns section though
1496                            answer.header.rcode = 3
1497                            rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
1498                        else:
1499                            for rec in zonedict[origin]['NS']:
1500                                rrnslist.append({origin:{'NS':[rec]}})
1501                answer.header.ancount = len(rranswerlist)
1502                answer.header.nscount = len(rrnslist)
1503                answer.header.arcount = len(rraddlist)
1504                answer.answerlist = rranswerlist
1505                answer.authlist = rrnslist
1506                answer.addlist = rraddlist
1507                cbfunc(query, addr, server, dorecursion, flist, [answer])
1508            else:
1509                cbfunc(query, addr, server, dorecursion, flist, [])
1510
1511    def handle_update(self, msg, addr, ns):
1512        zkey = ''
1513        slaves = []
1514        for zonekey in self.zdict.keys():
1515            if (self.zdict[zonekey]['type'] == 'master' and
1516                self.zdict[zonekey]['origin'] == msg.zone.zname):
1517                zkey = zonekey
1518        if not zkey:
1519            log(2,'SENDING NOTAUTH UPDATE ERROR')
1520            errormsg = self.error(msg.header.id, msg.zone.zname,
1521                                  msg.zone.ztype, msg.zone.zclass, 9)
1522            return errormsg, '', slaves
1523        # find the slaves for the zone
1524        if self.zdict[zkey].has_key('slaves'):
1525            slaves = self.zdict[zkey]['slaves']
1526        origin = self.zdict[zkey]['origin']
1527        zd = self.zdict[zkey]['zonedata']
1528        # check the permissions
1529        if not ns.config.allowupdate(msg, addr[0], addr[1]):
1530            log(2,'SENDING REFUSED UPDATE ERROR')
1531            errormsg = self.error(msg.header.id, msg.zone.zname,
1532                                  msg.zone.ztype, msg.zone.zclass, 5)
1533            return errormsg, origin, slaves
1534        # now check the prereqs
1535        temprrset = {}
1536        for rr in msg.prlist:
1537            rrname = rr.keys()[0]
1538            rrtype = rr[rrname].keys()[0]
1539            dbrec = rr[rrname][rrtype][0]
1540            if dbrec['ttl'] != 0:
1541                log(2,'FORMERROR(1)')
1542                errormsg = self.error(msg.header.id, msg.zone.zname,
1543                                      msg.zone.ztype, msg.zone.zclass, 1)
1544                return errormsg, origin, slaves
1545            if rrname.rfind(msg.zone.zname) == -1:
1546                log(2,'NOTZONE(10)')
1547                errormsg = self.error(msg.header.id, msg.zone.zname,
1548                                      msg.zone.ztype, msg.zone.zclass, 10)
1549                return errormsg, origin, slaves
1550            if dbrec['class'] == 'ANY':
1551                if dbrec['rdata']:
1552                    log(2,'FORMERROR(1)')
1553                    errormsg = self.error(msg.header.id, msg.zone.zname,
1554                                          msg.zone.ztype, msg.zone.zclass, 1)
1555                    return errormsg, origin, slaves
1556                if rrtype == 'ANY':
1557                    if not zd.has_key(rrname):
1558                        log(2,'NXDOMAIN(3)')
1559                        errormsg = self.error(msg.header.id, msg.zone.zname,
1560                                              msg.zone.ztype, msg.zone.zclass, 3)
1561                        return errormsg, origin, slaves
1562                else:
1563                    rrsettest = 0
1564                    if zd.has_key(rrname):
1565                        if zd[rrname].has_key(rrtype):
1566                            rrsettest = 1
1567                    if not rrsettest:
1568                        log(2,'NXRRSET(8)')
1569                        errormsg = self.error(msg.header.id, msg.zone.zname,
1570                                              msg.zone.ztype, msg.zone.zclass, 8)
1571                        return errormsg, origin, slaves
1572            if dbrec['class'] == 'NONE':
1573                if dbrec['rdata']:
1574                    log(2,'FORMERROR(1)')
1575                    errormsg = self.error(msg.header.id, msg.zone.zname,
1576                                          msg.zone.ztype, msg.zone.zclass, 1)
1577                    return errormsg, origin, slaves
1578                if rrtype == 'ANY':
1579                    if zd.has_key(rrname):
1580                        log(2,'YXDOMAIN(6)')
1581                        errormsg = self.error(msg.header.id, msg.zone.zname,
1582                                              msg.zone.ztype, msg.zone.zclass, 6)
1583                        return errormsg, origin, slaves
1584                else:
1585                    if zd.has_key(rrname):
1586                        if zd[rrname].has_key(rrtype):
1587                            log(2,'YXRRSET(7)')
1588                            errormsg = self.error(msg.header.id, msg.zone.zname,
1589                                                  msg.zone.ztype, msg.zone.zclass, 7)
1590                            return errormsg, origin, slaves
1591            if dbrec['class'] == msg.zone.zclass:
1592                if temprrset.has_key(rrname):
1593                    if not temprrset[rrname].has_key(rrtype):
1594                        temprrset[rrname][rrtype] = []
1595                else:
1596                    temprrset[rrname] = {}
1597                    temprrset[rrname][rrtype] = []
1598                temprrset[rrname][rrtype].append(dbrec)
1599            else:
1600                log(2,'FORMERROR(1)')
1601                errormsg = self.error(msg.header.id, msg.zone.zname,
1602                                      msg.zone.ztype, msg.zone.zclass, 1)
1603                return errormsg, origin, slaves
1604        for nodename in temprrset.keys():
1605            if not self.rrmatch(temprrset[nodename],zd[nodename]):
1606                log(2,'NXRRSET(8)')
1607                errormsg = self.error(msg.header.id, msg.zone.zname,
1608                                      msg.zone.ztype, msg.zone.zclass, 8)
1609                return errormsg, origin, slaves
1610
1611        # update section prescan
1612        for rr in msg.uplist:
1613            rrname = rr.keys()[0]
1614            rrtype = rr[rrname].keys()[0]
1615            dbrec = rr[rrname][rrtype][0]
1616            if rrname.rfind(msg.zone.zname) == -1:
1617                log(2,'NOTZONE(10)')
1618                errormsg = self.error(msg.header.id, msg.zone.zname,
1619                                      msg.zone.ztype, msg.zone.zclass, 10)
1620                return errormsg, origin, slaves
1621            if dbrec['class'] == msg.zone.zclass:
1622                if rrtype in ['ANY','MAILA','MAILB','AXFR']:
1623                    log(2,'FORMERROR(1)')
1624                    errormsg = self.error(msg.header.id, msg.zone.zname,
1625                                          msg.zone.ztype, msg.zone.zclass, 1)
1626                    return errormsg, origin, slaves
1627            elif dbrec['class'] == 'ANY':
1628                if dbrec['ttl'] != 0 or dbrec['rdata'] or rrtype in ['MAILA','MAILB','AXFR']:
1629                    log(2,'FORMERROR(1)')
1630                    errormsg = self.error(msg.header.id, msg.zone.zname,
1631                                          msg.zone.ztype, msg.zone.zclass, 1)
1632                    return errormsg, origin, slaves
1633            elif dbrec['class'] == 'NONE':
1634                if dbrec['ttl'] != 0 or rrtype in ['ANY','MAILA','MAILB','AXFR']:
1635                    log(2,'FORMERROR(1)')
1636                    errormsg = self.error(msg.header.id, msg.zone.zname,
1637                                          msg.zone.ztype, msg.zone.zclass, 1)
1638                    return errormsg, origin, slaves
1639            else:
1640                log(2,'FORMERROR(1)')
1641                errormsg = self.error(msg.header.id, msg.zone.zname,
1642                                      msg.zone.ztype, msg.zone.zclass, 1)
1643                return errormsg, origin, slaves
1644
1645        # now handle actual update
1646        curserial = zd[msg.zone.zname]['SOA'][0]['serial']
1647        # update the soa serial here
1648        clearupdatehist = 0
1649        if len(msg.uplist) > 0:
1650            # initialize history structure
1651            if not self.updates.has_key(zkey):
1652                self.updates[zkey] = {}
1653                self.updates[zkey][curserial] = {'removed':[],
1654                                                 'added':[]}
1655            if curserial == 2**32:
1656                newserial = 2
1657                clearupdatehist = 1
1658            else:
1659                newserial = curserial + 1
1660            self.updates[zkey][newserial] = {'removed':[],
1661                                             'added':[]}
1662            zd[msg.zone.zname]['SOA'][0]['serial'] = newserial
1663        for rr in msg.uplist:
1664            rrname = rr.keys()[0]
1665            rrtype = rr[rrname].keys()[0]
1666            dbrec = rr[rrname][rrtype][0]
1667            if dbrec['class'] == msg.zone.zclass:
1668                if rrtype == 'SOA':
1669                    if zd.has_key(rrname):
1670                        if zd[rrname].has_key('SOA'):
1671                            if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']:
1672                                del zd[rrname]['SOA'][0]
1673                                zd[rrname]['SOA'].append(dbrec)
1674                                clearupdatehist = 1
1675                elif rrtype == 'WKS':
1676                    if zd.has_key(rrname):
1677                        if zd[rrname].has_key('WKS'):
1678                            rdata = zd[rrname]['WKS'][0]
1679                            oldrr = {rrname:{'WKS':[rdata]}}
1680                            self.updates[zkey][curserial]['removed'].append(oldrr)
1681                            del zd[rrname]['WKS'][0]
1682                            zd[rrname]['WKS'].append(dbrec)
1683                            newrr = {rrname:{'WKS':[dbrec]}}
1684                            self.updates[zkey][newserial]['added'].append(newrr)
1685                else:
1686                    if zd.has_key(rrname):
1687                        if not zd[rrname].has_key(rrtype):
1688                            zd[rrname][rrtype] = []
1689                    else:
1690                        zd[rrname] = {}
1691                        zd[rrname][rrtype] = []
1692                    zd[rrname][rrtype].append(dbrec)
1693                    newrr = {rrname:{rrtype:[dbrec]}}
1694                    self.updates[zkey][newserial]['added'].append(newrr)
1695            elif dbrec['class'] == 'ANY':
1696                if rrtype == 'ANY':
1697                    if rrname == msg.zone.zname:
1698                        if zd.has_key(rrname):
1699                            for dnstype in zd[rrname].keys():
1700                                if dnstype not in ['SOA','NS']:
1701                                    for rdata in zd[rrname][dnstype]:
1702                                        oldrr = {rrname:{dnstype:[rdata]}}
1703                                        self.updates[zkey][curserial]['removed'].append(oldrr)
1704                                    del zd[rrname][dnstype]                                   
1705                    else:
1706                        if zd.has_key(rrname):
1707                            for dnstype in zd[rrname].keys():
1708                                for rdata in zd[rrname][dnstype]:
1709                                    oldrr = {rrname:{dnstype:[rdata]}}
1710                                    self.updates[zkey][curserial]['removed'].append(oldrr)
1711                            del zd[rrname]
1712                else:
1713                    if zd.has_key(rrname):
1714                        if zd[rrname].has_key(rrtype):
1715                            if rrname == msg.zone.zname:
1716                                if rrtype not in ['SOA','NS']:
1717                                    for rdata in zd[rrname][dnstype]:
1718                                        oldrr = {rrname:{dnstype:[rdata]}}
1719                                        self.updates[zkey][curserial]['removed'].append(oldrr)
1720                                    del zd[rrname][rrtype]
1721                            else:
1722                                for rdata in zd[rrname][dnstype]:
1723                                    oldrr = {rrname:{dnstype:[rdata]}}
1724                                    self.updates[zkey][curserial]['removed'].append(oldrr)
1725                                del zd[rrname][rrtype]
1726            elif dbrec['class'] == 'NONE':
1727                if not (rrname == msg.zone.zname and rrtype in ['SOA','NS']):
1728                    if zd.had_key(rrname):
1729                        if zd[rrname].has_key(rrtype):
1730                            for i in range(len(zd[rrname][rrtype])):
1731                                if dbrec == zd[rrname][rrtype][i]:
1732                                    rdata = zd[rrname][dnstype][i]
1733                                    oldrr = {rrname:{dnstype:[rdata]}}
1734                                    self.updates[zkey][curserial]['removed'].append(oldrr)
1735                                    del zd[rrname][rrtype][i]
1736                            if len(zd[rrname][rrtype]) == 0:
1737                                del zd[rrname][rrtype]
1738        if clearupdatehist:
1739            self.updates[zkey] = {}
1740        log(2,'SENDING UPDATE NOERROR MSG')
1741        noerrormsg = self.error(msg.header.id, msg.zone.zname,
1742                              msg.zone.ztype, msg.zone.zclass, 0)
1743        return noerrormsg, origin, slaves
1744
1745class dnscache:
1746    def __init__(self,cachezone):
1747        self.cachedb = cachezone
1748        # go through and set all of the root ttls to zero
1749        for node in self.cachedb.keys():
1750            for rtype in self.cachedb[node].keys():
1751                for rr in self.cachedb[node][rtype]:
1752                    rr['ttl'] = 0
1753                    if rtype == 'NS':
1754                        rr['rtt'] = 0
1755        # add special entries for localhost
1756        self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]}
1757        self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]}
1758        self.cachedb['']['SOA'] = []
1759        self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb',
1760                                        'rname':'cachedb@localhost','serial':1,'refresh':10800,
1761                                        'retry':3600,'expire':604800,'minimum':3600})
1762
1763    def hasrdata(self, irrdata, rrdatalist):
1764        # compare everything but ttls
1765        test = 0
1766        testrrdata = irrdata.copy()
1767        del testrrdata['ttl']
1768        for rrdata in rrdatalist:
1769            temprrdata = rrdata.copy()
1770            del temprrdata['ttl']
1771            if temprrdata == testrrdata:
1772                test = 1
1773        return test
1774
1775    def add(self, rr, qzone, nsdname):
1776        # NOTE: can't cache records from sites
1777        # that don't own those records (i.e. example.com
1778        # can't give us A records for www.example.net)
1779        name = rr.keys()[0]
1780        if (qzone != '') and (name[-len(qzone):] != qzone):
1781            log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone)
1782            return
1783        rtype = rr[name].keys()[0]
1784        rdata = rr[name][rtype][0]
1785        if rdata['ttl'] < 3600:
1786            log(2,'low ttl: ' + str(rdata['ttl']))
1787            rdata['ttl'] = 3600
1788        rdata['ttl'] = int(time.time() + rdata['ttl'])
1789        if rtype == 'NS':
1790            rdata['rtt'] = 0
1791        name = name.lower()
1792        rtype = rtype.upper()
1793        if self.cachedb.has_key(name):
1794            if self.cachedb[name].has_key(rtype):
1795                if not self.hasrdata(rdata, self.cachedb[name][rtype]):
1796                    self.cachedb[name][rtype].append(rdata)
1797                    log(3,'appended rdata to ' +
1798                        name + '(' + rtype + ') in cache')
1799                else:
1800                    log(3,'same rdata for ' + name + '(' +
1801                        rtype + ') is already in cache')
1802            else:
1803                self.cachedb[name][rtype] = [rdata]
1804                log(3,'appended ' + rtype + ' and rdata to node ' +
1805                    name + ' in cache')
1806        else:
1807            self.cachedb[name] = {rtype:[rdata]}
1808            log(3,'added node ' + name + '(' + rtype + ') to cache')
1809        self.reap()
1810
1811    def addneg(self, qname, querytype, queryclass):
1812        if not self.cachedb.has_key(qname):
1813            self.cachedb['qname'] = {querytype: [{'ttl':time.time()+3600}]}
1814        else:
1815            if not self.cachedb[qname].has_key(querytype):
1816                self.cachedb[qname][querytype] = [{'ttl':time.time()+3600}]
1817   
1818    def haskey(self, qname, querytype, msg=''):
1819        log(3,'looking for ' + qname + '(' + querytype + ') in cache')
1820        if self.cachedb.has_key(qname):
1821            rranswerlist = []
1822            rrnslist = []
1823            rraddlist = []
1824            if self.cachedb[qname].has_key('CNAME'):
1825                if querytype != 'CNAME':
1826                    nodetype = 'CNAME'
1827                    while nodetype == 'CNAME':
1828                        if len(self.cachedb[qname]['CNAME'][0].keys()) > 1:
1829                            log(3,'Adding CNAME to cache answer')
1830                            rranswerlist.append({qname:{'CNAME':[self.cachedb[qname]['CNAME'][0]]}})
1831                        qname = self.cachedb[qname]['CNAME'][0]['cname']
1832                        if self.cachedb.has_key(qname):
1833                            nodetype = self.cachedb[qname].keys()[0]
1834                        else:
1835                            # shouldn't have a CNAME that points to nothing
1836                            return
1837            if querytype == 'ANY':
1838                for type in self.cache[qname].keys():
1839                    for rec in self.cachedb[qname][type]:
1840                        # can't append negative entries
1841                        if len(rec.keys()) > 1:
1842                            rranswerlist.append({qname:{type:[rec]}})
1843            elif self.cachedb[qname].has_key(querytype):
1844                for rec in self.cachedb[qname][querytype]:
1845                    if len(rec.keys()) > 1:
1846                        rranswerlist.append({qname:{querytype:[rec]}})
1847            if rranswerlist:
1848                if msg:
1849                    answer = message()
1850                    answer.header.id = msg.header.id
1851                    answer.header.qr = 1
1852                    answer.header.opcode = msg.header.opcode
1853                    answer.header.ra = 1
1854                    answer.question.qname = msg.question.qname
1855                    answer.question.qtype = msg.question.qtype
1856                    answer.question.qclass = msg.question.qclass
1857                    answer.header.rcode = 0
1858                    answer.header.ancount = len(rranswerlist)
1859                    answer.answerlist = rranswerlist
1860                    return answer
1861                else:
1862                    return 1
1863        else:
1864            log(3,'Cache has no node for ' + qname)
1865       
1866    def getnslist(self, qname):
1867        # find the best nameserver to ask from the cache
1868        tokens = qname.split('.')
1869        nsdict = {}
1870        curtime = time.time()
1871        for i in range(len(tokens)):
1872            domainname = '.'.join(tokens[i:])
1873            if self.cachedb.has_key(domainname):
1874                if self.cachedb[domainname].has_key('NS'):
1875                    for nsrec in self.cachedb[domainname]['NS']:
1876                        badserver = 0
1877                        if nsrec.has_key('badtill'):
1878                            if nsrec['badtill'] < curtime:
1879                                del nsrec['badtill']
1880                            else:
1881                                badserver = 1
1882                        if badserver:
1883                            log(2,'BAD SERVER, not using ' + nsrec['nsdname'])
1884                        if self.cachedb.has_key(nsrec['nsdname']) and not badserver:
1885                            if self.cachedb[nsrec['nsdname']].has_key('A'):
1886                                for arec in self.cachedb[nsrec['nsdname']]['A']:
1887                                    nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
1888                                                            'ip':arec['address']}
1889                    if nsdict:
1890                        break
1891        if not nsdict:
1892            domainname = ''
1893            # nothing in the cache matches so give back the root servers
1894            for nsrec in self.cachedb['']['NS']:
1895                badserver = 0
1896                if nsrec.has_key('badtill'):
1897                    if curtime > nsrec['badtill']:
1898                        del nsrec['badtill']
1899                    else:
1900                        badserver = 1
1901                if not badserver:
1902                    for arec in self.cachedb[nsrec['nsdname']]['A']:
1903                        nsdict[(nsrec['rtt'])] = {'name':nsrec['nsdname'],'ip':arec['address']}
1904
1905        return (domainname, nsdict)
1906
1907    def badns(self, zonename, nsdname):
1908        if self.cachedb.has_key(zonename):
1909            if self.cachedb[zonename].has_key('NS'):
1910                for nsrec in self.cachedb[zonename]['NS']:
1911                    if nsrec['nsdname'] == nsdname:
1912                        log(2,'Setting ' + nsdname + ' as bad nameserver')
1913                        nsrec['badtill'] = time.time() + 3600
1914       
1915
1916    def updatertt(self, qname, zone, rtt):
1917        if self.cachedb.has_key(zone):
1918            if self.cachedb[zone].has_key('NS'):
1919                for rr in self.cachedb[zone]['NS']:
1920                    if rr['nsdname'] == qname:
1921                        log(2,'updating rtt for ' + qname + ' to ' + str(rtt))
1922                        rr['rtt'] = rtt
1923
1924    def reap(self):
1925        # expire all old records
1926        ntime = time.time()
1927        for nodename in self.cachedb.keys():
1928            for rrtype in self.cachedb[nodename].keys():
1929                for rdata in self.cachedb[nodename][rrtype]:
1930                    ttl = rdata['ttl']
1931                    if ttl != 0:
1932                        if ttl < ntime:
1933                            self.cachedb[nodename][rrtype].remove(rdata)
1934                if len(self.cachedb[nodename][rrtype]) == 0:
1935                    del self.cachedb[nodename][rrtype]
1936            if len(self.cachedb[nodename]) == 0:
1937                del self.cachedb[nodename]
1938                       
1939        return
1940
1941    def zonetrans(self, queryid):
1942        # build a list of messages
1943        # each message contains one rr of the zone
1944        # the first and last message are the
1945        # SOA records
1946        zonedata = self.cachedb
1947        rrlist = []
1948        soa = {'':{'SOA':[zonedata['']['SOA'][0]]}}
1949        for nodename in zonedata.keys():
1950            for rrtype in zonedata[nodename].keys():
1951                if not (rrtype == 'SOA' and nodename == ''):
1952                    for rr in zonedata[nodename][rrtype]:
1953                        rrlist.append({nodename:{rrtype:[rr]}})
1954        rrlist.insert(0,soa)
1955        rrlist.append(soa)
1956        msglist = []
1957        for rr in rrlist:
1958            msg = message()
1959            msg.header.id = queryid
1960            msg.header.qr = 1
1961            msg.header.aa = 1
1962            msg.header.rd = 0
1963            msg.header.qdcount = 1
1964            msg.question.qname = 'cache'
1965            msg.question.qtype = 'AXFR'
1966            msg.question.qclass = 'IN'
1967            msg.header.ancount = 1
1968            msg.answerlist.append(rr)
1969            msglist.append(msg)
1970        return msglist
1971
1972class gethostaddr(asyncore.dispatcher):
1973    def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'):
1974        asyncore.dispatcher.__init__(self)
1975        self.msg = message()
1976        self.msg.question.qname = hostname
1977        self.msg.question.qtype = 'A'
1978        self.cbfunc = cbfunc
1979        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
1980        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
1981        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
1982        self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
1983
1984    def handle_read(self):
1985        replydata, addr = self.socket.recvfrom(1500)
1986        self.close()
1987        try:
1988            replymsg = message(replydata)
1989        except:
1990            log(0,'unable to process packet')
1991            return
1992        answername = replymsg.question.qname
1993        cname = ''
1994        # go through twice to catch cnames after A recs
1995        for rr in replymsg.answerlist:
1996            rrname = rr.keys()[0]
1997            rrtype = rr[rrname].keys()[0]
1998            dbrec = rr[rrname][rrtype][0]
1999            if rrname == answername and rrtype == 'CNAME':
2000                answername = dbrec['cname']
2001                cname = answername
2002        for rr in replymsg.answerlist:
2003            rrname = rr.keys()[0]
2004            rrtype = rr[rrname].keys()[0]
2005            dbrec = rr[rrname][rrtype][0]
2006            if rrname == answername and rrtype == 'A':
2007                self.cbfunc(dbrec['address'])
2008                return
2009        # if we got a cname and no A send query for cname
2010        if cname:
2011            self.msg = message()
2012            self.msg.question.qname = cname
2013            self.msg.question.qtype = 'A'
2014            self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
2015        else:
2016            self.cbfunc('')
2017
2018    def writable(self):
2019        return 0
2020
2021    def handle_write(self):
2022        pass
2023
2024    def handle_connect(self):
2025        pass
2026
2027    def handle_close(self):
2028        self.close()
2029
2030    def log_info (self, message, type='info'):
2031        if __debug__ or type != 'info':
2032            log(0,'%s: %s' % (type, message))
2033
2034class simpleudprequest(asyncore.dispatcher):
2035    def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''):
2036        asyncore.dispatcher.__init__(self)
2037        self.gotanswer = 0
2038        self.msg = msg
2039        self.cbfunc = cbfunc
2040        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
2041        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
2042        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
2043        self.outqkey = outqkey
2044        self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
2045
2046    def handle_read(self):
2047        replydata, addr = self.socket.recvfrom(1500)
2048        self.close()
2049        try:
2050            replymsg = message(replydata)
2051        except:
2052            log(0,'unable to process packet')
2053            return
2054        self.cbfunc(replymsg, self.outqkey)
2055
2056    def writable(self):
2057        return 0
2058
2059    def handle_write(self):
2060        pass
2061
2062    def handle_connect(self):
2063        pass
2064
2065    def handle_close(self):
2066        self.close()
2067
2068    def log_info (self, message, type='info'):
2069        if __debug__ or type != 'info':
2070            log(0,'%s: %s' % (type, message))
2071
2072class simpletcprequest(asyncore.dispatcher):
2073    def __init__(self, msg, cbfunc, cbparams=[], serveraddr='127.0.0.1', errorfunc=''):
2074        asyncore.dispatcher.__init__(self)
2075        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
2076        self.query = msg
2077        self.cbfunc = cbfunc
2078        self.cbparams = cbparams
2079        self.errorfunc = errorfunc
2080        msgdata = msg.buildpkt()
2081        ml = inttoasc(len(msgdata))
2082        if len(ml) == 1:
2083            ml = chr(0) + ml
2084        self.buffer = ml+msgdata
2085        self.rbuffer = ''
2086        self.rmsgleft = 0
2087        self.rrlist = []
2088        log(2,'sending tcp request to ' + serveraddr)
2089        self.connect((serveraddr,53))
2090
2091    def recv (self, buffer_size):
2092        try:
2093            data = self.socket.recv (buffer_size)
2094            if not data:
2095                # a closed connection is indicated by signaling
2096                # a read condition, and having recv() return 0.
2097                self.handle_close()
2098                return ''
2099            else:
2100                return data
2101        except socket.error, why:
2102            # winsock sometimes throws ENOTCONN
2103            if why[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]:
2104                self.handle_close()
2105                return ''
2106            else:
2107                raise socket.error, why
2108
2109    def handle_connect(self):
2110        pass
2111
2112    def handle_msg(self, msg):
2113        if self.query.question.qtype == 'AXFR':
2114            if len(self.rrlist) == 0:
2115                if len(msg.answerlist) == 0:
2116                    if self.errorfunc:
2117                        self.errorfunc(self.cbparams[0])
2118                    self.close()
2119                    return
2120            rr = msg.answerlist[0]
2121            rrname = rr.keys()[0]
2122            rrtype = rr[rrname].keys()[0]
2123            self.rrlist.append(rr)
2124            if rrtype == 'SOA' and len(self.rrlist) > 1:
2125                self.close()
2126                if self.cbparams:
2127                    self.cbfunc(self.rrlist, self.cbparams)
2128                else:
2129                    self.cbfunc(self.rrlist)
2130        else:
2131            self.close()
2132            if self.cbparams:
2133                self.cbfunc(msg, self.cbparams)
2134            else:
2135                self.cbfunc(msg)
2136
2137    def handle_read(self):
2138        data = self.recv(8192)
2139        if len(self.rbuffer) == 0:
2140            self.rmsglength = asctoint(data[:2])
2141            data = data[2:]
2142        self.rbuffer = self.rbuffer + data
2143        while len(self.rbuffer) >= self.rmsglength and self.rmsglength != 0:
2144            msgdata = self.rbuffer[:self.rmsglength]
2145            self.rbuffer = self.rbuffer[self.rmsglength:]
2146            if len(self.rbuffer) == 0:
2147                self.rmsglength = 0
2148            else:
2149                self.rmsglength = asctoint(self.rbuffer[:2])
2150                self.rbuffer = self.rbuffer[2:]
2151            try:
2152                self.handle_msg(message(msgdata))
2153            except:
2154                return
2155           
2156    def writable(self):
2157        return (len(self.buffer) > 0)
2158   
2159    def handle_write(self):
2160        sent = self.send(self.buffer)
2161        self.buffer = self.buffer[sent:]
2162
2163    def handle_close(self):
2164        if self.errorfunc:
2165            self.errorfunc(self.query.question.qname)
2166        self.close()
2167
2168    def log_info (self, message, type='info'):
2169        if __debug__ or type != 'info':
2170            log(0,'%s: %s' % (type, message))
2171
2172class udpdnsserver(asyncore.dispatcher):
2173    def __init__(self, port, dnsserver):
2174        asyncore.dispatcher.__init__(self)
2175        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
2176        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
2177        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)   
2178        self.bind(('',port))
2179        self.dnsserver = dnsserver
2180        self.maxmsgsize = 500
2181
2182    def handle_read(self):
2183        try:
2184            while 1:
2185                msgdata, addr = self.socket.recvfrom(1500)
2186                self.dnsserver.handle_packet(msgdata, addr, self)
2187        except socket.error, why:
2188            if why[0] != asyncore.EWOULDBLOCK:
2189                raise socket.error, why
2190
2191    def sendpackets(self, msglist, addr):
2192        for msg in msglist:
2193            msgdata = msg.buildpkt()
2194            if len(msgdata) > self.maxmsgsize:
2195                msg.header.tc = 1
2196                # take off all the answers to ensure
2197                # the packet size is small enough
2198                msg.header.ancount = 0
2199                msg.header.nscount = 0
2200                msg.header.arcount = 0
2201                msg.answerlist = []
2202                msg.authlist = []
2203                msg.addlist = []
2204                msgdata = msg.buildpkt()
2205            self.sendto(msgdata, addr)
2206       
2207    def writable(self):
2208        return 0
2209
2210    def handle_write(self):
2211        pass
2212
2213    def handle_connect(self):
2214        pass
2215
2216    def handle_close(self):
2217        # print '1:In handle close'
2218        return
2219
2220    def log_info (self, message, type='info'):
2221        if __debug__ or type != 'info':
2222            log(0,'%s: %s' % (type, message))
2223
2224class tcpdnschannel(asynchat.async_chat):
2225    def __init__(self, server, s, addr):
2226        asynchat.async_chat.__init__(self, s)
2227        self.server = server
2228        self.addr = addr
2229        self.set_terminator(None)
2230        self.databuffer = ''
2231        self.msglength = 0
2232        log(3,'Created new tcp channel')
2233
2234    def collect_incoming_data(self, data):
2235        if self.msglength == 0:
2236            self.msglength = asctoint(data[:2])
2237            data = data[2:]
2238        self.databuffer = self.databuffer + data
2239        if len(self.databuffer) == self.msglength:
2240            # got entire message
2241            self.server.dnsserver.handle_packet(self.databuffer, self.addr, self)
2242            self.databuffer = ''
2243           
2244    def sendpackets(self, msglist, addr):
2245        for msg in msglist:
2246            x = msg.buildpkt()
2247            ml = inttoasc(len(x))
2248            if len(ml) == 1:
2249                ml = chr(0) + ml
2250            self.push(ml+x)
2251        self.close()
2252
2253    def log_info (self, message, type='info'):
2254        if __debug__ or type != 'info':
2255            log(0,'%s: %s' % (type, message))
2256
2257class tcpdnsserver(asyncore.dispatcher):
2258    def __init__(self, port, dnsserver):
2259        asyncore.dispatcher.__init__(self)
2260        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
2261        self.set_reuse_addr()
2262        self.bind(('',port))
2263        self.listen(5)
2264        self.dnsserver = dnsserver
2265
2266    def handle_accept(self):
2267        conn, addr = self.accept()
2268        tcpdnschannel(self, conn, addr)
2269
2270    def handle_close(self):
2271        self.close()
2272
2273    def log_info (self, message, type='info'):
2274        if __debug__ or type != 'info':
2275            log(0,'%s: %s' % (type, message))
2276
2277class nameserver:
2278    def __init__(self, resolver, localconfig):
2279        self.resolver = resolver
2280        self.config = localconfig
2281        self.zdb = self.config.zonedatabase
2282        self.last_reap_time = time.time()
2283        self.maint_int = 10
2284        self.slavesupdating = []
2285        self.notifys = []
2286        self.sentnotify = []
2287        self.notify_retry_time = 30
2288        self.notify_retries = 4
2289        self.askedsoa = {}
2290        self.soatimeout = 10
2291
2292    def error(self, id, qname, querytype, queryclass, rcode):
2293        error = message()
2294        error.header.id = id
2295        error.header.rcode = rcode
2296        error.header.qr = 1
2297        error.question.qname = qname
2298        error.question.qtype = querytype
2299        error.question.qclass = queryclass
2300        return error
2301
2302    def need_zonetransfer(self, zkey, origin, masterip, trynum=0):
2303        self.askedsoa[zkey] = {'masterip':masterip,
2304                               'senttime':time.time(),
2305                               'origin':origin,
2306                               'trynum':trynum+1}
2307        query = message()
2308        query.header.id = random.randrange(1,32768)
2309        query.header.rd = 0
2310        query.question.qname = origin
2311        query.question.qtype = 'SOA'
2312        query.question.qclass = 'IN'
2313        log(3,'slave checking for new data in ' + origin)
2314        simpleudprequest(query, self.handle_soaquery,
2315                         masterip, zkey)
2316
2317    def handle_soaquery(self, msg, zkey):
2318        origin = msg.question.qname
2319        masterip = self.askedsoa[zkey]['masterip']
2320        del self.askedsoa[zkey]
2321        if zkey not in self.slavesupdating:
2322            self.slavesupdating.append(zkey)
2323            query = message()
2324            query.header.id = random.randrange(1,32768)
2325            query.header.rd = 0
2326            query.question.qname = origin
2327            query.question.qtype = 'AXFR'
2328            query.question.qclass = 'IN'
2329            log(3,'Updating slave zone: ' + zkey)
2330            simpletcprequest(query, self.handle_zonetrans,
2331                             [zkey],masterip,self.handle_zterror)
2332
2333    def handle_zonetrans(self, rrlist, params):
2334        log(1,'handling zone transfer')
2335        zonekey = params[0]
2336        self.zdb.update_zone(rrlist, params)
2337        self.slavesupdating.remove(zonekey)
2338
2339    def handle_zterror(self, zonekey):
2340        self.slavesupdating.remove(zonekey)
2341        self.zdb.remove_zone(zonekey)
2342
2343    def rrmatch(self, rrset1, rrset2):
2344        for rrtype in rrset1.keys():
2345            if rrtype not in rrset2.keys():
2346                return
2347            else:
2348                if len(rrset1[rrtype]) != len(rrset2[rrtype]):
2349                    return
2350        return 1
2351
2352    def process_notify(self, msg, ipaddr, port):
2353        (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port)
2354        goodzkey = ''
2355        for zkey in zkeys:
2356            origin = self.zdb.getorigin(zkey)
2357            if origin == msg.question.qname:
2358                masterip = self.zdb.getmasterip(zkey)
2359                if masterip:
2360                    goodzkey = zkey
2361        if goodzkey:
2362            log(3,'got NOTIFY from ' + masterip)
2363            self.need_zonetransfer(goodzkey, origin, masterip, 0)
2364        return
2365
2366    def notify(self):
2367        curtime = time.time()
2368        for origin, ipaddr, trynum, senttime in self.sentnotify:
2369            if senttime + self.notify_retry_time > curtime:
2370                self.notifys.append((origin, ipaddr, trynum))
2371                self.sentnotify.remove((origin, ipaddr, trynum, senttime))
2372        for origin, ipaddr, trynum in self.notifys:
2373            msg = message()
2374            msg.question.qname = origin
2375            msg.question.qtype = 'SOA'
2376            msg.question.qclass = 'IN'
2377            msg.header.opcode = 4
2378            # there probably is a better way to do this
2379            if self.resolver:
2380                self.resolver.send_to([msg],(ipaddr,53))
2381                if trynum+1 <= self.notify_retries:
2382                    self.sentnotify.append((origin,ipaddr,trynum+1,curtime))
2383        self.notifys = []
2384       
2385    def handle_packet(self, msgdata, addr, server):
2386        # self.reap()
2387        try:
2388            msg = message(msgdata)
2389        except:
2390            return
2391        # find a matching view
2392        (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1])
2393        if not msg.header.qr and msg.header.opcode == 5:
2394            log(2,'GOT UPDATE PACKET')
2395            # check the zone section
2396            if (msg.header.zocount != 1 or
2397                msg.zone.ztype != 'SOA' or
2398                msg.zone.zclass != 'IN'):
2399                log(2,'SENDING FORMERR UPDATE ERROR')
2400                errormsg = self.error(msg.header.id, msg.zone.zname,
2401                                  msg.zone.ztype, msg.zone.zclass, 1)
2402                server.sendpackets([errormsg],addr)
2403            else:
2404                (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self)
2405                if answer.header.rcode == 0:
2406                    # schedule NOTIFYs to slaves
2407                    for ipaddr in slaves:
2408                        self.notifys.append((origin, ipaddr, 0))
2409                server.sendpackets([answer],addr)
2410        elif msg.header.opcode == 4:
2411            if msg.header.qr:
2412                log(0,'got NOTIFY response')
2413                for origin, ipaddr, trynum, senttime in self.sentnotify:
2414                    if ipaddr == addr[0] and msg.question.qname == origin:
2415                        self.sentnotify.remove((origin, ipaddr, trynum, senttime))
2416            else:
2417                log(0,'got NOTIFY')
2418                self.process_notify(msg, addr[0], addr[1])
2419        elif not msg.header.qr and msg.header.opcode == 0:
2420            # it's a question
2421            qname = msg.question.qname.lower()
2422            log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype +
2423                ') from ' + addr[0])
2424            # handle special version packet
2425            if (msg.question.qtype == 'TXT' and
2426                msg.question.qclass == 'CH'):
2427                if qname == 'version.bind':                   
2428                    server.sendpackets([getversion(qname,
2429                                                  msg.header.id,
2430                                                  msg.header.rd,
2431                                                  dorecursion, '1.0')],addr)
2432                elif qname == 'version.oak':
2433                    server.sendpackets([getversion(qname,
2434                                                  msg.header.id,
2435                                                  msg.header.rd,
2436                                                  dorecursion, '1.0')],addr)
2437                return
2438            self.zdb.lookup(zkeys, msg, addr, server, dorecursion,
2439                               flist, self.lookup_callback)
2440
2441    def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist):
2442        if answerlist:
2443            server.sendpackets(self.config.outpackets(answerlist), addr)
2444        elif dorecursion:
2445            if msg.question.qtype in ['AXFR','IXFR']:
2446                if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR':
2447                    if self.resolver:
2448                        server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr)
2449                else:
2450                    # won't forward zone transfers and
2451                    # don't handle recursive zone transfers
2452                    server.sendpackets([self.error(msg.header.id, msg.question.qname,
2453                                                   msg.question.qtype,
2454                                                   msg.question.qclass,2)],addr)
2455            else:
2456                self.resolver.handle_query(msg, addr, flist, server.sendpackets)
2457                             
2458    def reap(self):
2459        log(4,'in nameserver reap')
2460        # do all maintenence (interval) stuff here
2461        if self.resolver:
2462            self.resolver.reap()
2463        self.notify()
2464        curtime = time.time()
2465        if curtime > (self.last_reap_time + self.maint_int):
2466            self.last_reap_time = curtime
2467            # do zone transfers here if slave server and haven't asked for soa
2468            for (zkey, origin, masterip) in self.zdb.getslaves(curtime):
2469                if not self.askedsoa.has_key(zkey):
2470                    self.need_zonetransfer(zkey, origin, masterip)
2471        for zkey in self.askedsoa.keys():
2472            if curtime > self.askedsoa[zkey]['senttime'] + self.soatimeout:
2473                if self.askedsoa[zkey]['trynum'] > 3:
2474                    self.zdb.remove_zone(zkey)
2475                    del self.askedsoa[zkey]                   
2476                else:
2477                    masterip = self.askedsoa[zkey]['masterip']
2478                    origin = self.askedsoa[zkey]['origin']
2479                    trynum = self.askedsoa[zkey]['trynum']
2480                    del self.askedsoa[zkey]
2481                    self.need_zonetransfer(zkey, origin, masterip, trynum)
2482               
2483    def log_info (self, message, type='info'):
2484        if __debug__ or type != 'info':
2485            log(0,'%s: %s' % (type, message))
2486
2487class resolver(asyncore.dispatcher):
2488    def __init__(self, cache, port=0):
2489        asyncore.dispatcher.__init__(self)
2490        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
2491        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
2492        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)   
2493        self.bind(('',port))
2494        self.cache = cache
2495        self.outqnum = 0
2496        self.outq = {}
2497        self.holdq = {}
2498        self.holdtime = 10
2499        self.holdqlength = 100
2500        self.last_reap_time = time.time()
2501        self.maint_int = 10
2502        self.timeout = 3
2503
2504    def getoutqkey(self):
2505        self.outqnum = self.outqnum + 1
2506        if self.outqnum == 99999:
2507            self.outqnum = 1
2508        return str(self.outqnum)
2509
2510    def error(self, id, qname, querytype, queryclass, rcode):
2511        error = message()
2512        error.header.id = id
2513        error.header.rcode = rcode
2514        error.header.qr = 1
2515        error.question.qname = qname
2516        error.question.qtype = querytype
2517        error.question.qclass = queryclass
2518        return error
2519
2520    def qpacket(self, id, qname, querytype, queryclass):
2521        # create a question
2522        query = message()
2523        query.header.id = id
2524        query.header.rd = 0
2525        query.question.qname = qname
2526        query.question.qtype = querytype
2527        query.question.qclass = queryclass
2528        return query
2529
2530    def send_to(self, msglist, addr):
2531        for msg in msglist:
2532            data = msg.buildpkt()
2533            if len(data) > 512:
2534                # packet to big
2535                msg.header.tc = 1
2536                msg.header.ancount = 0
2537                msg.answerlist = []
2538                msg.header.nscount = 0
2539                msg.authlist = []
2540                msg.header.arcount = 0
2541                msg.addlist = []
2542                self.socket.sendto(msg.buildpkt(), addr)
2543            else:
2544                self.socket.sendto(data, addr)
2545
2546    def handle_read(self):
2547        try:
2548            while 1:
2549                msgdata, addr = self.socket.recvfrom(1500)
2550                # should put 'try' here in production server
2551                self.handle_packet(msgdata, addr)
2552        except socket.error, why:
2553            if why[0] != asyncore.EWOULDBLOCK:
2554                raise socket.error, why
2555
2556    def handle_packet(self, msgdata, addr):
2557        try:
2558            msg = message(msgdata)
2559        except:
2560            return
2561        if not msg.header.qr:
2562            self.handle_query(msg, addr, [], self.send_to)
2563        else:
2564            log(2,'received unsolicited reply')
2565
2566
2567    def handle_query(self, msg, addr, flist, cbfunc):
2568        qname = msg.question.qname
2569        querytype = msg.question.qtype
2570        queryclass = msg.question.qclass
2571        # check the cache first
2572        answer = self.cache.haskey(qname,querytype,msg)
2573        if answer:
2574            cbfunc([answer], addr)
2575            log(2,'sent answer for ' + qname + '(' + querytype +
2576                ') from cache')
2577        else:
2578            # check if query is already in progess
2579            for oqkey in self.outq.keys():
2580                if (self.outq[oqkey]['qname'] == qname and
2581                    self.outq[oqkey]['querytype'] == querytype):
2582                    log(2,'query already in progress for '+qname+'('+querytype+')')
2583                    # put entry in hold queue to try later
2584                    hqrec = {'processtime':time.time()+self.holdtime,
2585                             'query':msg,'addr':addr,
2586                             'qname':qname,'querytype':querytype,
2587                             'queryclass':queryclass,
2588                             'cbfunc':cbfunc}
2589                    self.putonhold(hqrec)
2590                    return
2591               
2592            outqkey = self.getoutqkey()+str(msg.header.id)               
2593            self.outq[outqkey] = {'query':msg,
2594                                  'addr':addr,
2595                                  'qname':qname,
2596                                  'querytype':querytype,
2597                                  'queryclass':queryclass,
2598                                  'cbfunc':cbfunc,
2599                                  'answerlist':[],
2600                                  'addlist':[],
2601                                  'qsent':0}
2602            if flist:
2603                self.outq[outqkey]['flist'] = flist
2604                self.askfns(outqkey)
2605            else:
2606                self.askns(outqkey)
2607
2608    def putonhold(self,hqrec):
2609        hqid = hqrec['qname']+hqrec['querytype']
2610        if self.holdq.has_key(hqid):
2611            if len(self.holdq[hqid]) < self.holdqlength:
2612                hqrec['processtime']=time.time()+self.holdtime
2613                self.holdq[hqid].append(hqrec)
2614       
2615           
2616    def askns(self, outqkey):
2617        qname = self.outq[outqkey]['qname']
2618        querytype = self.outq[outqkey]['querytype']
2619        queryclass = self.outq[outqkey]['queryclass']
2620        # don't try more than 10 times to avoid loops
2621        if self.outq[outqkey]['qsent'] == 10:
2622            del self.outq[outqkey]
2623            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
2624                   ' POSSIBLE LOOP')
2625            return
2626        # find the best nameservers to ask from the cache
2627        (qzone, nsdict) = self.cache.getnslist(qname)
2628        if not nsdict:
2629            # there are no good servers
2630            if self.outq[outqkey]['addr'] != 'IQ':
2631                qid = self.outq[outqkey]['query'].header.id
2632                self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
2633                                   self.outq[outqkey]['addr'])
2634            del self.outq[outqkey]
2635            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
2636                   'no good name servers to ask')
2637            return
2638        # pick the best nameserver
2639        rtts = nsdict.keys()
2640        rtts.sort()
2641        bestnsip = nsdict[rtts[0]]['ip']
2642        bestnsname = nsdict[rtts[0]]['name']
2643        # fill in the callback data structure
2644        id=random.randrange(1,32768)
2645        self.outq[outqkey]['nsqueriedlastip'] = bestnsip
2646        self.outq[outqkey]['nsqueriedlastname'] = bestnsname
2647        self.outq[outqkey]['nsdict'] = nsdict
2648        self.outq[outqkey]['qzone'] = qzone
2649        self.outq[outqkey]['qsenttime'] = time.time()
2650        self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1
2651        # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53))
2652        self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
2653                                                      self.handle_response, bestnsip, outqkey)
2654        # update rtt so that we ask a different server next time
2655        self.cache.updatertt(bestnsname,qzone,1)
2656        log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname +
2657               ') for ' + qname + '(' + querytype + ')')
2658
2659    def askfns(self, outqkey):
2660        flist = self.outq[outqkey]['flist']
2661        qname = self.outq[outqkey]['qname']
2662        querytype = self.outq[outqkey]['querytype']
2663        queryclass = self.outq[outqkey]['queryclass']
2664        self.outq[outqkey]['qsenttime'] = time.time()       
2665        id=random.randrange(1,32768)
2666        # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53))
2667        self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
2668                                                         self.handle_fresponse, flist[0], outqkey)
2669        log(2,''+outqkey+'|sent query to forwarder')
2670
2671    def handle_response(self, msg, outqkey):
2672        # either reponse:
2673        # 1. contains a name error
2674        # 2. answers the question
2675        #    (cache data and return it)
2676        # 3. is (contains) a CNAME and qtype isn't
2677        #    (cache cname and change qname to it)
2678        #    (check if qname and qtype are in any other rrs in the response)
2679        #    (must check cache again here)
2680        # 4. contains a better delegation
2681        #    (cache the delegation and start again)
2682        # 5. is aserver failure
2683        #    (delete server from list and try again)
2684
2685        # make sure that original question is still outstanding
2686        if not self.outq.has_key(outqkey):
2687            # should never get here
2688            # if we do we aren't doing housekeeping of callbacks very well
2689            log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
2690            return
2691
2692        querytype = self.outq[outqkey]['querytype']
2693        if msg.header.rcode not in [1,2,4,5]:       
2694            # update rtt time
2695            rtt = time.time() - self.outq[outqkey]['qsenttime']
2696            nsname = self.outq[outqkey]['nsqueriedlastname']
2697            zone = self.outq[outqkey]['qzone']
2698            self.cache.updatertt(nsname,zone,rtt)
2699
2700        if msg.header.rcode == 3:
2701            log(2,outqkey+'|GOT Name Error for ' + msg.question.qname +
2702                '(' + msg.question.qtype + ')')
2703            # name error
2704            # cache negative answer
2705            self.cache.addneg(self.outq[outqkey]['qname'],
2706                              self.outq[outqkey]['querytype'],
2707                              self.outq[outqkey]['queryclass'])
2708            if self.outq[outqkey]['addr'] != 'IQ':
2709                answer = message()               
2710                answer.question.qname = self.outq[outqkey]['query'].question.qname
2711                answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2712                answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2713                answer.header.id = self.outq[outqkey]['query'].header.id
2714                answer.header.qr = 1
2715                answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2716                answer.header.ra = 1
2717                self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2718            del self.outq[outqkey]
2719           
2720        elif msg.header.ancount > 0:
2721            # answer (may be CNAME)
2722            haveanswer = 0
2723            cname = ''
2724            log(2,'CACHING ANSWERLIST ENTRIES')
2725            for rr in msg.answerlist:
2726                rrname = rr.keys()[0]
2727                rrtype = rr[rrname].keys()[0]
2728                if ((rrname == msg.question.qname or rrname == cname ) and
2729                    rrtype == msg.question.qtype):
2730                    haveanswer = 1
2731                if rrname == msg.question.qname and rrtype == 'CNAME':
2732                    cname = rr[rrname][rrtype][0]['cname']
2733                self.cache.add(rr, self.outq[outqkey]['qzone'],
2734                               self.outq[outqkey]['nsqueriedlastname'])
2735            if haveanswer:
2736                if self.outq[outqkey]['addr'] != 'IQ':
2737                    log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname +
2738                        '(' + msg.question.qtype + ')' )
2739                    answer = message()
2740                    answer.answerlist = msg.answerlist + self.outq[outqkey]['answerlist']
2741                    answer.header.ancount = len(answer.answerlist)
2742                    answer.question.qname = self.outq[outqkey]['query'].question.qname
2743                    answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2744                    answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2745                    answer.header.id = self.outq[outqkey]['query'].header.id
2746                    answer.header.qr = 1
2747                    answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2748                    answer.header.ra = 1
2749                    self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2750                    log(2,outqkey+'|sent answer retrieved from remote server for ' +
2751                           self.outq[outqkey]['query'].question.qname)
2752                else:
2753                    log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' +
2754                        msg.question.qtype + ')')
2755                del self.outq[outqkey]
2756            elif cname:
2757                log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')')
2758                self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist
2759                self.outq[outqkey]['qname'] = cname
2760                self.askns(outqkey)
2761            else:
2762                log(2,outqkey+'|GOT BOGUS answer for '  + msg.question.qname + '(' +
2763                    msg.question.qtype + ')')
2764                del self.outq[outqkey]
2765           
2766        elif msg.header.nscount > 0 and msg.header.ancount == 0:
2767            log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')')
2768            # delegation
2769            # cache the nameserver rrs and start over
2770            # if there are no glue records for nameservers must fetch them first
2771            log(2,'CACHING AUTHLIST ENTRIES')
2772            for rr in msg.authlist:
2773                self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2774            log(2,'CACHING ADDLIST ENTRIES')
2775            for rr in msg.addlist:
2776                self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2777            rrlist = msg.authlist+msg.addlist
2778            fetchglue = 0
2779            nscount = 0
2780            for rr in msg.authlist:
2781                nodename = rr.keys()[0]
2782                if rr[nodename].keys()[0] == 'NS':
2783                    nscount = nscount + 1
2784                    nsdname = rr[nodename]['NS'][0]['nsdname']
2785                    if not self.cache.haskey(nsdname,'A'):
2786                        log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)')
2787                        fetchglue = fetchglue + 1
2788                        # need to fetch A rec
2789                        noutqkey = self.getoutqkey()+str(random.randrange(1,32768))
2790                        self.outq[noutqkey] = {'query':'',
2791                                               'addr':'IQ',
2792                                               'qname':nsdname,
2793                                               'querytype':'A',
2794                                               'queryclass':'IN',
2795                                               'qsent':0}
2796                        log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)')
2797                        self.askns(noutqkey)
2798            if not nscount:
2799                log(2,outqkey+'|Dropping query (no ns recs) for ' +
2800                       msg.question.qname + '(' + msg.question.qtype + ')' )
2801                del self.outq[outqkey]
2802            elif fetchglue == nscount:
2803                log(2,outqkey+'|Stalling query (no glue recs) for ' +
2804                       msg.question.qname + '(' + msg.question.qtype + ')')
2805                self.putonhold(self.outq[outqkey])
2806                del self.outq[outqkey]               
2807            else:
2808                log(2,outqkey+'|got (some) glue with delegation')
2809                self.askns(outqkey)
2810
2811        elif msg.header.rcode in [1,2,4,5]:
2812            log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode))
2813            log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' + 
2814                 self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname)
2815            # don't ask this server for a while
2816            self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2817            self.askns(outqkey)
2818        else:
2819            log(2,outqkey+'|GOT UNPARSEABLE REPLY')
2820            msg.printpkt()
2821
2822    def handle_fresponse(self, msg, outqkey):
2823        if msg.header.rcode in [1,2,4,5]:
2824            self.outq[outqkey]['flist'].pop(0)
2825            if len(self.outq[outqkey]['flist']) == 0:
2826                qid = self.outq[outqkey]['query'].header.id
2827                qname = self.outq[outqkey]['qname']
2828                querytype = self.outq[outqkey]['querytype']
2829                queryclass = self.outq[outqkey]['queryclass']
2830                self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
2831                                             self.outq[outqkey]['addr'])
2832                del self.outq[outqkey]
2833            else:
2834                self.askfns(outqkey)
2835        else:
2836            answer = message()
2837            answer.header.id = self.outq[outqkey]['query'].header.id
2838            answer.header.qr = 1
2839            answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2840            answer.header.ra = 1
2841            answer.question.qname = self.outq[outqkey]['query'].question.qname
2842            answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2843            answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2844            answer.header.ancount = msg.header.ancount
2845            answer.header.nscount = msg.header.nscount
2846            answer.header.arcount = msg.header.arcount                       
2847            answer.answerlist = msg.answerlist
2848            answer.authlist = msg.authlist
2849            answer.addlist = msg.addlist               
2850            if msg.header.rcode == 3:
2851                # name error
2852                # cache negative answer
2853                self.cache.addneg(self.outq[outqkey]['qname'],
2854                                  self.outq[outqkey]['querytype'],
2855                                  self.outq[outqkey]['queryclass'])
2856            else:
2857                # cache all rrs
2858                for rr in msg.answerlist:
2859                    self.cache.add(rr,'','forwarder')
2860                for rr in msg.authlist:
2861                    self.cache.add(rr,'','forwarder')
2862                for rr in msg.addlist:
2863                    self.cache.add(rr,'','forwarder')
2864            self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2865            del self.outq[outqkey]
2866
2867    def writable(self):
2868        return 0
2869
2870    def handle_write(self):
2871        pass
2872
2873    def handle_connect(self):
2874        pass
2875
2876    def handle_close(self):
2877        # print '1:In handle close'
2878        return
2879
2880    def process_holdq(self):
2881        curtime = time.time()
2882        for hqkey in self.holdq.keys():
2883            for hqrec in self.holdq[hqkey]:
2884                if curtime >= hqrec['processtime']:
2885                    log(2,'processing held query')
2886                    answer = self.cache.haskey(hqrec['qname'],
2887                                               hqrec['querytype'],
2888                                               hqrec['query'])
2889                    if answer:
2890                        hqrec['cbfunc']([answer], hqrec['addr'])
2891                        log(2,'sent answer for ' + hqrec['qname'] +
2892                            '(' + hqrec['querytype'] +  ') from cache')
2893                    self.holdq[hqkey].remove(hqrec)
2894            if len(self.holdq[hqkey]) == 0:
2895                del self.holdq[hqkey]
2896
2897    def reap(self):
2898        self.process_holdq()
2899        curtime = time.time()
2900        log(3,timestamp() + 'processed HOLDQ (sockets: ' +
2901            str(len(asyncore.socket_map.keys()))+')')
2902        if curtime > (self.last_reap_time + self.maint_int):
2903            self.last_reap_time = curtime
2904            for outqkey in self.outq.keys():
2905                if curtime > self.outq[outqkey]['qsenttime'] + self.timeout:
2906                    log(2,'query for '+self.outq[outqkey]['qname']+'('+
2907                        self.outq[outqkey]['querytype']+') expired')
2908                    # don't set forwarders as bad
2909                    if not self.outq[outqkey].has_key('flist'):
2910                        self.cache.badns(self.outq[outqkey]['qzone'],
2911                                         self.outq[outqkey]['nsqueriedlastname'])
2912                    if self.outq[outqkey].has_key('request'):
2913                        log(3,'closing socket for expired query')
2914                        self.outq[outqkey]['request'].close()
2915                    del self.outq[outqkey]
2916        return
2917
2918    def log_info (self, message, type='info'):
2919        if __debug__ or type != 'info':
2920            log(0,'%s: %s' % (type, message))
2921
2922
2923def run(configobj):
2924    global loglevel
2925    r = resolver(dnscache(configobj.cached))
2926    ns = nameserver(r, configobj)
2927    udpds = udpdnsserver(53, ns)
2928    tcpds = tcpdnsserver(53, ns)
2929    loglevel = configobj.loglevel
2930    try:
2931        loop(ns.reap)
2932    except KeyboardInterrupt:
2933        print 'server done'
2934
2935if __name__ == '__main__':
2936    sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')
2937    zonedict = {'example.net':{'origin':'example.net',
2938                               'filename':'db.example.net',
2939                               'type':'master',
2940                               'slaves':[]}}
2941
2942
2943    zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu',
2944                                         'filename':'db.servers.csail.mit.edu',
2945                                         'type':'master',
2946                                         'slaves':[]}}
2947
2948    zonedict2 = {'example.net':{'origin':'example.net',
2949                                'filename':'db.example.net',
2950                                'type':'slave',
2951                                'masterip':'127.0.0.1'}}
2952    readzonefiles(zonedict)
2953    lconfig = dnsconfig()
2954    lconfig.zonedatabase = zonedb(zonedict)
2955    pr = zonefileparser()
2956    pr.parse('','db.ca')
2957    lconfig.cached = pr.getzdict()
2958    lconfig.loglevel = 3
2959
2960    run(lconfig)
Note: See TracBrowser for help on using the repository browser.