#!/usr/bin/python
#    Python Domain Name Server
#    Copyright (C) 2002  Digital Lumber, Inc.

#    This library is free software; you can redistribute it and/or
#    modify it under the terms of the GNU Lesser General Public
#    License as published by the Free Software Foundation; either
#    version 2.1 of the License, or (at your option) any later version.

#    This library is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#    Lesser General Public License for more details.

#    You should have received a copy of the GNU Lesser General Public
#    License along with this library; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import socket
import asyncore
import asynchat
import select
import types
import random
import time
import signal
import string
import sys
import sipb_xen_database
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
     ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT

# EXAMPLE ZONE FILE DATA STRUCTURE

# NOTE:
# There are no trailing dots in the internal data
# structure.  Although it's hard to tell from reading
# the RFCs, the trailing dot on a name is only a
# convention used by master files, resolvers and servers
# to mark the name as absolute, i.e. no domain name
# should be appended to it.  Names in queries on the
# network carry no trailing dots.

# Example zone in the internal layout: owner name -> record type ->
# list of rdata dicts.  Every rdata dict carries 'class' and 'ttl'.
examplenet = {'example.net':{'SOA':[{'class':'IN',
                                     'ttl':10,
                                     'mname':'ns1.example.net',
                                     'rname':'hostmaster.example.net',
                                     'serial':1,
                                     'refresh':10800,
                                     'retry':3600,
                                     'expire':604800,
                                     'minimum':3600}],
                             'NS':[{'class':'IN',
                                    'ttl':10,
                                    'nsdname':'ns1.example.net'},
                                   {'class':'IN',
                                    'ttl':10,
                                    'nsdname':'ns2.example.net'}],
                             'MX':[{'class':'IN',
                                    'ttl':10,
                                    'preference':10,
                                    'exchange':'mail.example.net'}]},
              'server1.example.net':{'A':[{'class':'IN',
                                           'ttl':10,
                                           'address':'10.1.2.3'}]},
              'www.example.net':{'CNAME':[{'class':'IN',
                                           'ttl':10,
                                           'cname':'server1.example.net'}]},
              'router.example.net':{'A':[{'class':'IN',
                                          'ttl':10,
                                          'address':'10.1.2.1'},
                                         {'class':'IN',
                                          'ttl':10,
                                          'address':'10.2.1.1'}]}
              }

# Logging defaults: log() writes only messages whose level is at most
# loglevel, and writes them to logfile.
loglevel, logfile = 0, sys.stdout

try:
    file
except NameError:
    # Python 3 has no builtin file(); provide a stand-in built on open().
    # buffering defaults to -1 (system default) because open() rejects
    # buffering=0 for text-mode files -- with the old default of 0 every
    # file(path) call raised ValueError.
    def file(name, mode='r', buffer=-1):
        return open(name, mode, buffer)

def log(level,msg):
    """Write msg and a newline to logfile unless level exceeds loglevel."""
    if not level <= loglevel:
        return
    logfile.write(msg + '\n')

def timestamp():
    """Return the current local time formatted as 'MM/DD/YY HH:MM:SS-'."""
    stamp = time.strftime('%m/%d/%y %H:%M:%S')
    return stamp + '-'

def inttoasc(number):
    """Return the big-endian byte string for a non-negative integer.

    The result is as short as possible but at least one byte, e.g.
    0 -> '\\x00' and 256 -> '\\x01\\x00'.  If hex() cannot convert
    number, the problem is logged and the original exception re-raised
    (instead of failing later on an unbound local).
    """
    try:
        hs = hex(number)[2:]
    except Exception:
        log(0,'inttoasc cannot convert ' + repr(number))
        raise
    # Python 2 longs render as '0x...L'; drop the suffix
    if hs[-1:].upper() == 'L':
        hs = hs[:-1]
    # left-pad to whole bytes, then decode two hex digits per byte
    if len(hs) % 2:
        hs = '0' + hs
    return ''.join([chr(int(hs[k:k+2], 16)) for k in range(0, len(hs), 2)])
    
def asctoint(ascnum):
    """Interpret a byte string as a big-endian unsigned integer ('' -> 0)."""
    total = 0
    for ch in ascnum:
        # shift what we have up one byte and add the next one
        total = (total << 8) + ord(ch)
    return total

def ipv6net_aton(ip_string):
    """Pack a textual IPv6 address into its 16-byte network form.

    Accepts the full eight-group form and the '::' shorthand, including a
    leading or trailing '::' ('::1', 'fe80::', '::').  Embedded IPv4
    notation is not supported.  Raises ValueError for malformed groups or
    a wrong number of groups.
    """
    if '::' in ip_string:
        head, tail = ip_string.split('::', 1)
        headparts = []
        if head:
            headparts = head.split(':')
        tailparts = []
        if tail:
            tailparts = tail.split(':')
        # '::' stands for however many all-zero groups are missing
        missing = 8 - len(headparts) - len(tailparts)
        groups = headparts + ['0'] * missing + tailparts
    else:
        groups = ip_string.split(':')
    if len(groups) != 8:
        raise ValueError('bad IPv6 address ' + repr(ip_string))
    packed = []
    for group in groups:
        value = int(group, 16)
        if not 0 <= value <= 0xffff:
            raise ValueError('bad IPv6 group ' + repr(group))
        packed.append(chr(value >> 8) + chr(value & 0xff))
    return ''.join(packed)

def ipv6net_ntoa(packed_ip):
    """Render a packed IPv6 address as colon-separated hex groups.

    Each group is built from two bytes, so zeros inside a group are kept
    (0x2001 -> '2001'); leading zeros are dropped and no '::' compression
    is applied, e.g. '2001:db8:0:0:0:0:0:1'.  Raises ValueError if the
    input length is odd.
    """
    if len(packed_ip) % 2:
        raise ValueError('packed IPv6 address must have an even length')
    groups = []
    for k in range(0, len(packed_ip), 2):
        value = (ord(packed_ip[k]) << 8) | ord(packed_ip[k+1])
        groups.append('%x' % value)
    return ':'.join(groups)

def getversion(qname, id, rd, ra, versionstr):
    """Build the authoritative CH-class TXT reply to a version query.

    For 'version.bind' the answer is a CNAME to 'version.oak' followed by
    the TXT record for 'version.oak'; any other qname is answered with a
    TXT record directly.  id, rd and ra are copied into the reply header.
    """
    reply = message()
    hdr = reply.header
    hdr.id = id
    hdr.qr = 1
    hdr.aa = 1
    hdr.rd = rd
    hdr.ra = ra
    hdr.rcode = 0
    question = reply.question
    question.qname = qname
    question.qtype = 'TXT'
    question.qclass = 'CH'

    def txtnode():
        # a fresh TXT node carrying the version string
        return {'TXT':[{'txtdata':versionstr,
                        'ttl':360000,
                        'class':'CH'}]}

    if qname != 'version.bind':
        hdr.ancount = 1
        reply.answerlist.append({qname:txtnode()})
    else:
        hdr.ancount = 2
        reply.answerlist.append({qname:{'CNAME':[{'cname':'version.oak',
                                                  'ttl':360000,
                                                  'class':'CH'}]}})
        reply.answerlist.append({'version.oak':txtnode()})
    return reply

def getrcode(rcode):
    """Return a printable 'NAME(description)' string for a DNS RCODE.

    Codes 0-10 are named; anything else gives 'Unknown RCODE(<rcode>)'.
    """
    # index in this tuple == numeric RCODE
    descriptions = ('NOERROR(No error condition)',
                    'FORMERR(Format Error)',
                    'SERVFAIL(Internal failure)',
                    'NXDOMAIN(Name does not exist)',
                    'NOTIMP(Not Implemented)',
                    'REFUSED(Security violation)',
                    'YXDOMAIN(Name exists)',
                    'YXRRSET(RR exists)',
                    'NXRRSET(RR does not exist)',
                    'NOTAUTH(Server not Authoritative)',
                    'NOTZONE(Name not in zone)')
    for code, text in enumerate(descriptions):
        if rcode == code:
            return text
    return 'Unknown RCODE(' + str(rcode) + ')'

def printrdata(dnstype, rdata):
    """Return the master-file text for one record's RDATA.

    rdata is a record dict in the internal layout (names without trailing
    dots); domain names are written absolute, i.e. with a trailing dot.
    Supports A, AAAA, MX, NS, PTR, CNAME, TXT, HINFO, RP, SRV and SOA;
    returns None for anything else (e.g. LOC).
    """
    if dnstype in ('A', 'AAAA'):
        return rdata['address']
    elif dnstype == 'MX':
        return str(rdata['preference'])+'\t'+rdata['exchange']+'.'
    elif dnstype == 'NS':
        return rdata['nsdname']+'.'
    elif dnstype == 'PTR':
        return rdata['ptrdname']+'.'
    elif dnstype == 'CNAME':
        return rdata['cname']+'.'
    elif dnstype == 'TXT':
        # NOTE: embedded '"' characters are not escaped
        return '"' + rdata['txtdata'] + '"'
    elif dnstype == 'HINFO':
        return '"' + rdata['cpu'] + '" "' + rdata['os'] + '"'
    elif dnstype == 'RP':
        return rdata['mboxdname'] + '.\t' + rdata['txtdname'] + '.'
    elif dnstype == 'SRV':
        return (str(rdata['priority']) + '\t' + str(rdata['weight']) + '\t' +
                str(rdata['port']) + '\t' + rdata['target'] + '.')
    elif dnstype == 'SOA':
        return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+
                35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+
                str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )')
    return None

def makezonedatalist(zonedata, origin):
    """Flatten zone data into a list of [owner., type, rdata] entries.

    The first SOA of the origin node comes first, then the origin's other
    records, then the records of every other node.  Owner names get a
    trailing dot.
    """
    apex = origin + '.'
    apexnode = zonedata[origin]
    records = [[apex, 'SOA', apexnode['SOA'][0]]]
    for rtype in apexnode.keys():
        if rtype == 'SOA':
            continue
        records.extend([[apex, rtype, rdata] for rdata in apexnode[rtype]])
    for nodename in zonedata.keys():
        if nodename == origin:
            continue
        node = zonedata[nodename]
        for rtype in node.keys():
            records.extend([[nodename + '.', rtype, rdata] for rdata in node[rtype]])
    return records

def writezonefile(zonedata, origin, file):
    """Write zonedata to the writable object file in master-file format.

    One line per record: owner left-justified to 35 columns, TTL, the
    literal class 'IN', the type, and printrdata()'s RDATA text.
    """
    for owner, dnstype, rdata in makezonedatalist(zonedata, origin):
        head = owner.ljust(35) + str(rdata['ttl'])
        file.write(head + '\t\tIN\t' + dnstype + '\t' +
                   printrdata(dnstype, rdata) + '\n')

def readzonefiles(zonedict):
    """Load every zone listed in zonedict from its zone file.

    zonedict maps a key to a dict with 'filename' and 'origin'; on
    success the parsed data is stored under that dict's 'zonedata'.
    Zones whose file raises ZonefileError are logged and removed from
    zonedict.
    """
    # iterate over a snapshot: entries may be deleted inside the loop
    for k in list(zonedict.keys()):
        filepath = zonedict[k]['filename']
        try:
            pr = zonefileparser()
            pr.parse(zonedict[k]['origin'],filepath)
            zonedict[k]['zonedata'] = pr.getzdict()
        except ZonefileError as err:
            log(0,'Error reading zone file ' + filepath  + ' at line ' +
                str(err) + '\n')
            del zonedict[k]

def slowloop(tofunc='',timeout=5.0):
    """Run the channels in asyncore.socket_map using select().

    Loops until the socket map is empty.  tofunc (optional) is called
    after any select() that waited at least timeout seconds.  Exceptions
    raised by a channel's handler are logged and passed to that channel's
    handle_error(); asyncore.ExitNow and non-Exception errors (e.g.
    KeyboardInterrupt) propagate and end the loop.
    """
    if not tofunc:
        def tofunc(): return
    socketmap = asyncore.socket_map

    def dispatch(fd, handlername):
        # the channel may have been removed by an earlier handler
        try:
            obj = socketmap[fd]
        except KeyError:
            log(0,'KeyError in socket map')
            return
        try:
            getattr(obj, handlername)()
        except asyncore.ExitNow:
            raise
        except Exception:
            log(0,'calling HANDLE ERROR from loop')
            log(0,repr(obj))
            obj.handle_error()

    while socketmap:
        r = []; w = []; e = []
        for fd, obj in list(socketmap.items()):
            if obj.readable():
                r.append(fd)
            if obj.writable():
                w.append(fd)
        try:
            starttime = time.time()
            r, w, e = select.select(r, w, e, timeout)
            endtime = time.time()
            if endtime - starttime >= timeout:
                tofunc()
        except select.error as err:
            # err.args[0] is the errno on both Python 2 and 3
            if err.args[0] != EINTR:
                raise
            r = []; w = []; e = []
            log(0,'ERROR in select')

        for fd in r:
            dispatch(fd, 'handle_read_event')
        # writable sockets must get the write handler, not the read one
        for fd in w:
            dispatch(fd, 'handle_write_event')

def fastloop(tofunc='',timeout=5.0):
    """Run the channels in asyncore.socket_map using poll().

    Loops until the socket map is empty; tofunc (optional) is called
    after any poll() that waited at least timeout seconds.  A channel
    whose returned flags include POLLPRI, POLLERR, POLLHUP or POLLNVAL
    gets handle_error(); otherwise POLLIN/POLLOUT dispatch to the
    read/write handlers.  Handler exceptions are logged with a traceback
    and passed to handle_error(); asyncore.ExitNow and non-Exception
    errors (e.g. KeyboardInterrupt) propagate and end the loop.
    """
    import traceback
    if not tofunc:
        def tofunc(): return
    polltimeout = timeout*1000  # poll() takes milliseconds
    socketmap = asyncore.socket_map
    # NOTE(review): POLLPRI signals urgent data rather than an error; it is
    # still treated as a failure, as before -- confirm this is intended.
    badflags = ((select.POLLPRI, 'POLLPRI'), (select.POLLERR, 'POLLERR'),
                (select.POLLHUP, 'POLLHUP'), (select.POLLNVAL, 'POLLNVAL'))
    badvals = 0
    for flag, name in badflags:
        badvals = badvals | flag
    while socketmap:
        pollobj = select.poll()
        for fd, obj in list(socketmap.items()):
            flags = 0
            if obj.readable():
                flags = select.POLLIN
            if obj.writable():
                flags = flags | select.POLLOUT
            if flags:
                pollobj.register(fd, flags)
        try:
            starttime = time.time()
            r = pollobj.poll(polltimeout)
            endtime = time.time()
            if endtime-starttime >= timeout:
                tofunc()
        except select.error as err:
            # err.args[0] is the errno on both Python 2 and 3
            if err.args[0] != EINTR:
                raise
            r = []
            log(0,'ERROR in select')
        for fd, flags in r:
            try:
                obj = socketmap[fd]
            except KeyError:
                log(0,'KeyError in socket map')
                continue
            try:
                if flags & badvals:
                    for flag, name in badflags:
                        if flags & flag:
                            log(0,name)
                    obj.handle_error()
                else:
                    if flags & select.POLLIN:
                        obj.handle_read_event()
                    if flags & select.POLLOUT:
                        obj.handle_write_event()
            except asyncore.ExitNow:
                raise
            except Exception:
                # format the traceback directly (StringIO was never imported)
                log(0,'ERROR IN LOOP:')
                log(0,traceback.format_exc())
                log(0,repr(obj))
                obj.handle_error()

# Prefer the poll()-based loop wherever the platform provides select.poll.
loop = fastloop if hasattr(select, 'poll') else slowloop

class ZonefileError(Exception):
    """Raised when a zone file cannot be parsed.

    linenum is the line number at which parsing failed and errordesc an
    optional reason; str() renders them as 'linenum (errordesc)'.
    """
    def __init__(self, linenum, errordesc=''):
        self.linenum, self.errordesc = linenum, errordesc

    def __str__(self):
        return ''.join([str(self.linenum), ' (', self.errordesc, ')'])

class zonefileparser:
    """Parse master (zone) files into a nested zone-data dict.

    After parse(), getzdict() returns a dict mapping owner name -> DNS
    type -> list of rdata dicts (each carrying 'class' and 'ttl'), the
    same layout as the module-level example zone data.  Owner names and
    name-valued rdata are lower-cased and stored without trailing dots.
    """
    def __init__(self):
        self.zonedata = {}
        # record types getrrdata() can build rdata for
        self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX',
                         'NS','PTR','RP','SOA','SRV','TXT']

    def stripcomments(self, line):
        """Return line cut at the first ';' (start of a comment)."""
        # NOTE: a ';' inside a quoted TXT string is treated as a comment too
        i = line.find(';')
        if i >= 0:
            line = line[:i]
        return line

    def strip(self, line):
        """Remove one trailing newline; other whitespace is preserved."""
        if line[-1:] == '\n':
            line = line[:-1]
        return line

    def getzdict(self):
        """Return the zone data collected so far."""
        return self.zonedata

    def addorigin(self, origin, name):
        """Make name absolute (without trailing dot) relative to origin.

        '@' means the origin itself; a name ending in '.' is already
        absolute and only loses the dot; anything else gets '.origin'
        appended.
        """
        if name == '@':
            return origin
        if name[-1:] != '.':
            return name + '.' + origin
        return name[:-1]

    def getstrings(self, s):
        """Split s into words, or into its quoted pieces if it has '"'.

        In the quoted case, pieces that are empty or a single space (the
        gaps between quoted strings) are dropped.
        """
        if s.find('"') == -1:
            return s.split()
        rlist = []
        for i in s.split('"'):
            if i != '' and i != ' ':
                rlist.append(i)
        return rlist

    def getlocsize(self, s):
        """Convert a LOC size/precision in metres ('m' suffix optional).

        Returns (mantissa, exponent) such that the size in centimetres is
        about mantissa * 10**exponent (the mantissa is truncated).
        """
        if s[-1:] == 'm':
            size = float(s[:-1])*100
        else:
            size = float(s)*100
        i = 0
        while size > 9:
            size = size/10
            i = i + 1
        return (int(size),i)

    def getloclat(self, l,c):
        """Encode a LOC latitude or longitude.

        l is [degrees[, minutes[, seconds]]] (strings) and c the hemisphere
        letter.  Returns thousandths of an arc-second, added to 2**31 for
        N/E and subtracted from it for S/W; an unknown letter is logged
        and the bare magnitude returned.
        """
        deg = float(l[0])
        mins = 0
        secs = 0
        if len(l) == 3:
            mins = float(l[1])
            secs = float(l[2])
        elif len(l) == 2:
            mins = float(l[1])
        rval = ((((deg *60) + mins) * 60) + secs) * 1000
        if c in ['N','E']:
            rval = rval + (2**31)
        elif c in ['S','W']:
            rval = (2**31) - rval
        else:
            log(0,'ERROR: unsupported latitude/longitude direction')
        # int() gives a long on Python 2 when needed and, unlike long(),
        # also exists on Python 3
        return int(rval)

    def getgname(self, name, iter):
        """Expand $GENERATE substitutions in name for iterator value iter.

        Each '$' becomes the iterator value.  '${offset[,width[,base]]}'
        adds offset to it, zero-pads to width and formats it in base d
        (decimal), o (octal), x or X (hex); modifiers apply only to the
        '$' they follow.  '\\$' and '$$' produce a literal '$'.  The names
        '0' and 'O' expand to ''.  Raises ValueError on bad modifiers.
        """
        if name == '0' or name == 'O':
            return ''
        out = []
        i = 0
        n = len(name)
        while i < n:
            c = name[i]
            nxt = name[i+1:i+2]
            if c == '\\' and nxt == '$':
                out.append('$')
                i = i + 2
            elif c != '$':
                out.append(c)
                i = i + 1
            elif nxt == '$':
                out.append('$')
                i = i + 2
            else:
                offset = 0
                width = 0
                base = 'd'
                i = i + 1
                if nxt == '{':
                    j = name.find('}', i)
                    if j == -1:
                        raise ValueError('unterminated ${ in ' + repr(name))
                    owb = name[i+1:j].split(',')
                    if len(owb) > 3:
                        raise ValueError('too many $GENERATE modifiers')
                    offset = int(owb[0])
                    if len(owb) > 1:
                        width = int(owb[1])
                    if len(owb) > 2:
                        base = owb[2]
                    i = j + 1
                val = iter + offset
                if base == 'd':
                    rs = str(val)
                elif base == 'o':
                    rs = '%o' % val
                elif base == 'x':
                    rs = '%x' % val
                elif base == 'X':
                    rs = '%X' % val
                else:
                    raise ValueError('bad $GENERATE base ' + repr(base))
                if len(rs) < width:
                    rs = (width-len(rs))*'0' + rs
                out.append(rs)
        return ''.join(out)

    def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens, lineno=0):
        """Build the rdata dict for one record.

        tokens are the RDATA fields that follow the type; name-valued
        fields are lower-cased and made absolute with addorigin().
        Raises ZonefileError(lineno, ...) for an unsupported type.
        """
        rdata = {}
        rdata['class'] = dnsclass
        rdata['ttl'] = ttl
        if dnstype == 'A':
            rdata['address'] = tokens[0]
        elif dnstype == 'AAAA':
            rdata['address'] = tokens[0]
        elif dnstype == 'CNAME':
            rdata['cname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'HINFO':
            sl = self.getstrings(' '.join(tokens))
            rdata['cpu'] = sl[0]
            rdata['os'] = sl[1]
        elif dnstype == 'LOC':
            if 'N' in tokens:
                i = tokens.index('N')
            else:
                i = tokens.index('S')
            lat = self.getloclat(tokens[0:i],tokens[i])
            if 'E' in tokens:
                j = tokens.index('E')
            else:
                j = tokens.index('W')
            lng = self.getloclat(tokens[i+1:j],tokens[j])
            # defaults from RFC 1876: size 1m, horiz 10000m, vert 10m
            size = self.getlocsize('1m')
            horiz_pre = self.getlocsize('10000m')
            vert_pre = self.getlocsize('10m')
            extra = tokens[j+1:]   # altitude [size [horiz_pre [vert_pre]]]
            if len(extra) >= 2:
                size = self.getlocsize(extra[1])
            if len(extra) >= 3:
                horiz_pre = self.getlocsize(extra[2])
            if len(extra) == 4:
                vert_pre = self.getlocsize(extra[3])
            alttoken = extra[0]
            if alttoken[-1:] == 'm':
                alttoken = alttoken[:-1]
            # centimetres above a base 100,000 m (10,000,000 cm) below
            # the reference spheroid
            alt = int((float(alttoken)*100)+10000000)
            rdata['version'] = 0
            rdata['size'] = size
            rdata['horiz_pre'] = horiz_pre
            rdata['vert_pre'] = vert_pre
            rdata['latitude'] = lat
            rdata['longitude'] = lng
            rdata['altitude'] = alt
        elif dnstype == 'MX':
            rdata['preference'] = int(tokens[0])
            rdata['exchange'] = self.addorigin(origin,tokens[1].lower())
        elif dnstype == 'NS':
            rdata['nsdname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'PTR':
            rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'RP':
            rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower())
            rdata['txtdname'] = self.addorigin(origin,tokens[1].lower())
        elif dnstype == 'SOA':
            rdata['mname'] = self.addorigin(origin,tokens[0].lower())
            rdata['rname'] = self.addorigin(origin,tokens[1].lower())
            rdata['serial'] = int(tokens[2])
            rdata['refresh'] = int(tokens[3])
            rdata['retry'] = int(tokens[4])
            rdata['expire'] = int(tokens[5])
            rdata['minimum'] = int(tokens[6])
        elif dnstype == 'SRV':
            rdata['priority'] = int(tokens[0])
            rdata['weight'] = int(tokens[1])
            rdata['port'] = int(tokens[2])
            rdata['target'] = self.addorigin(origin,tokens[3].lower())
        elif dnstype == 'TXT':
            rdata['txtdata'] = self.getstrings(' '.join(tokens))[0]
        else:
            raise ZonefileError(lineno,'bad DNS type')
        return rdata

    def addrec(self, owner, dnstype, rrdata):
        """Append rrdata to owner's list of dnstype records."""
        self.zonedata.setdefault(owner, {}).setdefault(dnstype, []).append(rrdata)

    def parse(self, origin, f):
        """Parse zone file f into self.zonedata.

        f is an open file-like object (anything with readline()) or a
        path, which is opened and closed here; an unopenable path is
        logged and ignored.  origin (no trailing dot) is the initial
        origin for relative names.  Raises ZonefileError(lineno, reason)
        on malformed input.
        """
        closefile = 0
        if not hasattr(f, 'readline'):
            # must be a path
            try:
                f = file(f)
                closefile = 1
            except Exception:
                log(0,'Invalid path to zonefile')
                return
        try:
            self._parsefile(origin, f)
        finally:
            # don't leak the descriptor when a ZonefileError escapes
            if closefile:
                f.close()

    def _readentry(self, f, line, lineno):
        """Join a parenthesised multi-line entry; return (line, lineno)."""
        if line.find('(') >= 0:
            if line.find(')') == -1:
                while 1:
                    raw = f.readline()
                    if not raw:
                        raise ZonefileError(lineno, 'unterminated parenthesis')
                    lineno = lineno + 1
                    line2 = self.strip(self.stripcomments(raw))
                    # keep tokens on adjacent lines separate
                    line = line + ' ' + line2
                    if line2.find(')') >= 0:
                        break
            line = line.replace(')','').replace('(','')
        return line, lineno

    def _absorigin(self, name):
        """Lower-case an origin name and drop any trailing dot."""
        # NOTE: a name without a trailing dot is taken as-is, not appended
        # to the current origin (kept for compatibility with existing files)
        name = name.lower()
        if name[-1:] == '.':
            name = name[:-1]
        return name

    def _generate(self, origin, tokens, ttl, lineno):
        """Expand one '$GENERATE range lhs type rhs' directive."""
        lhs = tokens[2]
        dnstype = tokens[3].upper()
        rhs = tokens[4]
        rng = tokens[1].split('-')
        start = int(rng[0])
        i = rng[1].find('/')
        if i != -1:
            stop = int(rng[1][:i])+1
            step = int(rng[1][i+1:])
        else:
            stop = int(rng[1])+1
            step = 1
        for n in range(start, stop, step):
            # lower-case after expansion so the 'X' base survives;
            # getrrdata makes name-valued rdata absolute itself
            grhs = self.getgname(rhs, n).lower()
            rrdata = self.getrrdata(origin, dnstype, 'IN', ttl, [grhs], lineno)
            glhs = self.addorigin(origin, self.getgname(lhs, n).lower())
            self.addrec(glhs, dnstype, rrdata)

    def _parsefile(self, origin, f):
        """Parse every entry read from the open file f (see parse())."""
        lastowner = ''
        lastdnsclass = 'IN'  # default until a record names a class
        lastttl = 3600
        lineno = 0
        while 1:
            line = f.readline()
            if not line:
                break
            lineno = lineno + 1
            line = self.strip(self.stripcomments(line))
            if not line.strip():
                continue
            line, lineno = self._readentry(f, line, lineno)
            # now line equals the entire RR entry
            tokens = line.split()
            if not tokens:
                continue
            directive = tokens[0].upper()
            if directive == '$ORIGIN':
                if len(tokens) < 2:
                    raise ZonefileError(lineno, 'bad origin')
                origin = self._absorigin(tokens[1])
            elif directive == '$INCLUDE':
                if len(tokens) < 2:
                    raise ZonefileError(lineno, 'bad INCLUDE directive')
                incorigin = origin
                if len(tokens) > 2:
                    incorigin = self._absorigin(tokens[2])
                try:
                    # the path is used verbatim: file names are case-sensitive
                    f2 = file(tokens[1])
                    try:
                        self.parse(incorigin, f2)
                    finally:
                        f2.close()
                except Exception:
                    raise ZonefileError(lineno, 'bad INCLUDE directive')
            elif directive == '$TTL':
                try:
                    lastttl = int(tokens[1])
                except (IndexError, ValueError):
                    raise ZonefileError(lineno, 'bad TTL directive')
            elif directive == '$GENERATE':
                try:
                    self._generate(origin, tokens, lastttl, lineno)
                except (KeyError, IndexError, ValueError):
                    raise ZonefileError(lineno, 'bad GENERATE directive')
            else:
                try:
                    # if line begins with blank then owner is last owner
                    if line[0] in string.whitespace:
                        owner = lastowner
                    else:
                        owner = self.addorigin(origin, tokens[0].lower())
                        tokens = tokens[1:]
                    # line format is either: [class] [ttl] type RDATA
                    #                     or [ttl] [class] type RDATA
                    # find which token is the type, then backfill the rest
                    count = 0
                    for token in tokens:
                        if token.upper() in self.dnstypes:
                            break
                        count = count + 1
                    if count == 0:
                        ttl = lastttl
                        dnsclass = lastdnsclass
                    elif count == 1:
                        if tokens[0].isdigit():
                            ttl = int(tokens[0])
                            dnsclass = lastdnsclass
                        else:
                            ttl = lastttl
                            dnsclass = tokens[0].upper()
                        tokens = tokens[1:]
                    elif count == 2:
                        if tokens[0].isdigit():
                            ttl = int(tokens[0])
                            dnsclass = tokens[1].upper()
                        else:
                            ttl = int(tokens[1])
                            dnsclass = tokens[0].upper()
                        tokens = tokens[2:]
                    else:
                        raise ZonefileError(lineno,'bad ttl or class')
                    dnstype = tokens[0].upper()
                    rrdata = self.getrrdata(origin, dnstype, dnsclass,
                                            ttl, tokens[1:], lineno)
                    self.addrec(owner, dnstype, rrdata)
                    lastowner = owner
                    lastttl = ttl
                    lastdnsclass = dnsclass
                except ZonefileError:
                    raise
                except Exception:
                    raise ZonefileError(lineno,'unable to parse line')
        
class dnsconfig:
    """Server policy hooks: views, update permission, outgoing packets.

    This default implementation gives every client the same single zone
    key and passes outgoing packets through unchanged.
    """
    def __init__(self):
        self.cached = {}
        self.loglevel = 0

    def getview(self, msg, address, port):
        """Return (zone keys, use-resolver flag, forwarder addresses).

        The flag says whether the resolver is used, i.e. whether
        recursive queries are answered.
        """
        zonekeys = ['servers.csail.mit.edu']
        forwarders = []
        return zonekeys, 1, forwarders

    def allowupdate(self, msg, address, port):
        """Return 1 if updates are allowed.

        Only the zones returned by getview() can be updated.
        """
        # NOTE(review): allows updates from every address and port --
        # confirm this is intended before deploying.
        return 1

    def outpackets(self, packetlist):
        """Hook for modifying outgoing packets; returns them unchanged."""
        return packetlist

class dnsheader:
    """Header fields of a DNS message (see RFC 1035 section 4.1.1)."""
    def __init__(self, id=1):
        # 16-bit identifier chosen by the querier
        self.id = id
        # qr: query(0)/response(1); opcode: 4-bit kind of query;
        # aa: authoritative answer; tc: truncation flag (0 = not truncated)
        self.qr, self.opcode, self.aa, self.tc = 0, 0, 0, 0
        # rd: recursion desired; ra: recursion available;
        # z: reserved; rcode: response code (set in responses)
        self.rd, self.ra, self.z, self.rcode = 1, 0, 0, 0
        # section counts: questions (only 1 supported), answers,
        # authority name servers, additional records
        self.qdcount, self.ancount, self.nscount, self.arcount = 1, 0, 0, 0

class dnsquestion:
    """A question entry; defaults to an IN-class A query for 'localhost'."""
    def __init__(self):
        self.qname, self.qtype, self.qclass = 'localhost', 'A', 'IN'

class dnsupdatezone:
    """Empty class with no attributes or behaviour of its own.

    Presumably reserved for the zone section of dynamic UPDATE messages
    -- TODO confirm against the update-handling code.
    """

class message:
    def __init__(self, msgdata=''):
        if msgdata:
            self.header = dnsheader()
        else:
            self.header = dnsheader(id=random.randrange(1,32768))
        self.question = dnsquestion()
        self.answerlist = []
        self.authlist = []
        self.addlist = []
        self.u = ''
        self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA',
                       7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS',
                       12:'PTR',13:'HINFO',14:'MINFO',15:'MX',
                       16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV',
                       38:'A6',39:'DNAME',251:'IXFR',252:'AXFR',
                       253:'MAILB',254:'MAILA',255:'ANY'}
        self.rqtypes = {}
        for key in self.qtypes.keys():
            self.rqtypes[self.qtypes[key]] = key
        self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'}
        self.rqclasses = {}
        for key in self.qclasses.keys():
            self.rqclasses[self.qclasses[key]] = key

        if msgdata:
            self.processpkt(msgdata)

    def getdomainname(self, data, i):
        log(4,'IN GETDOMAINNAME')
        domainname = ''
        gotpointer = 0
        labellength= ord(data[i])
        log(4,'labellength:' + str(labellength))
        i = i + 1
        while labellength != 0:
            while labellength >= 192:
                # pointer
                if not gotpointer:
                    rindex = i + 1
                    gotpointer = 1
                    log(4,'got pointer')
                i = asctoint(chr(ord(data[i-1]) & 63)+data[i])
                log(4,'new index:'+str(i))
                labellength = ord(data[i])
                log(4,'labellength:' + str(labellength))
                i = i + 1
            if domainname:
                domainname = domainname + '.' + data[i:i+labellength]
            else:
                domainname = data[i:i+labellength]
            log(4,'domainname:'+domainname)
            i = i + labellength
            labellength = ord(data[i])
            log(4,'labellength:' + str(labellength))
            i = i + 1
        if not gotpointer:
            rindex = i

        return domainname.lower(), rindex

    def getrrdata(self, type, msgdata, rdlength, i):
        log(4,'unpacking RR data')
        rdata = msgdata[i:i+rdlength]
        if type == 'A':
            return {'address':socket.inet_ntoa(rdata)}
        elif type == 'AAAA':
            return {'address':ipv6net_ntoa(rdata)}
        elif type == 'CNAME':
            cname, i = self.getdomainname(msgdata,i)
            return {'cname':cname}
        elif type == 'HINFO':
            cpulen = ord(rdata[0])
            cpu = rdata[1:cpulen+1]
            return {'cpu':cpu,
                    'os':rdata[cpulen+2:]}
        elif type == 'LOC':
            return {'version':ord(rdata[0]),
                    'size':self.locsize(rdata[1]),
                    'horiz_pre':self.locsize(rdata[2]),
                    'vert_pre':self.locsize(rdata[3]),
                    'latitude':asctoint(rdata[4:8]),
                    'longitude':asctoint(rdata[8:12]),
                    'altitude':asctoint(rdata[12:16])}
        elif type == 'MX':
            exchange, i = self.getdomainname(msgdata,i+2)
            return {'preference':asctoint(rdata[:2]),
                    'exchange':exchange}
        elif type == 'NS':
            nsdname, i = self.getdomainname(msgdata,i)
            return {'nsdname':nsdname}
        elif type == 'PTR':
            ptrdname, i = self.getdomainname(msgdata,i)
            return {'ptrdname':ptrdname}
        elif type == 'RP':
            mboxdname, i = self.getdomainname(msgdata,i)
            txtdname, i = self.getdomainname(msgdata,i)
            return {'mboxdname':mboxdname,
                    'txtdname':txtdname}
        elif type == 'SOA':
            mname, i = self.getdomainname(msgdata,i)
            rname, i = self.getdomainname(msgdata,i)
            return {'mname':mname,
                    'rname':rname,
                    'serial':asctoint(msgdata[i:i+4]),
                    'refresh':asctoint(msgdata[i+4:i+8]),
                    'retry':asctoint(msgdata[i+8:i+12]),
                    'expire':asctoint(msgdata[i+12:i+16]),
                    'minimum':asctoint(msgdata[i+16:i+20])}
        elif type == 'SRV':
            target, i = self.getdomainname(msgdata,i+6)            
            return {'priority':asctoint(rdata[0:2]),
                    'weight':asctoint(rdata[2:4]),
                    'port':asctoint(rdata[4:6]),
                    'target':target}
        elif type == 'TXT':
            return {'txtdata':rdata[1:]}
        else:
            return {'rdata':rdata}
        
    def getrr(self, data, i):
        log(4,'unpacking RR name')
        name, i = self.getdomainname(data, i)
        type = asctoint(data[i:i+2])
        type = self.qtypes.get(type,chr(type))
        klass = asctoint(data[i+2:i+4])
        klass = self.qclasses.get(klass,chr(klass))
        ttl = asctoint(data[i+4:i+8])
        rdlength = asctoint(data[i+8:i+10])
        rrdata = self.getrrdata(type,data,rdlength,i+10)
        rrdata['ttl'] = ttl
        rrdata['class'] = klass
        rr = {name:{type:[rrdata]}}
        return rr, i+10+rdlength

    def processpkt(self, msgdata):
        self.header.id = asctoint(msgdata[:2])
        self.header.qr = ord(msgdata[2]) >> 7
        self.header.opcode = (ord(msgdata[2]) & 127) >> 3
        if self.header.opcode == 5:
            # UPDATE packet
            log(4,'processing UPDATE packet')
            del self.header.aa
            del self.header.tc
            del self.header.rd
            del self.header.ra
            del self.header.qdcount
            del self.header.ancount
            del self.header.nscount
            del self.header.arcount
            del self.question
            self.zone = dnsupdatezone()
            del self.answerlist
            del self.authlist
            del self.addlist
            self.header.z = 0
            self.header.rcode = ord(msgdata[3]) & 15
            self.header.zocount = asctoint(msgdata[4:6])
            self.header.prcount = asctoint(msgdata[6:8])
            self.header.upcount = asctoint(msgdata[8:10])
            self.header.arcount = asctoint(msgdata[10:12])
            self.zolist = []
            self.prlist = []
            self.uplist = []
            self.addlist = []
            i = 12
            for x in range(self.header.zocount):
                (dn, i) = self.getdomainname(msgdata,i)
                self.zone.zname = dn
                type = asctoint(msgdata[i:i+2])
                self.zone.ztype = self.qtypes.get(type,chr(type))
                klass = asctoint(msgdata[i+2:i+4])
                self.zone.zclass = self.qclasses.get(klass,chr(klass))
                i = i + 4
            for x in range(self.header.prcount):
                rr, i  = self.getrr(msgdata,i)
                self.prlist.append(rr)
            for x in range(self.header.upcount):
                rr, i  = self.getrr(msgdata,i)
                self.uplist.append(rr)
            for x in range(self.header.arcount):
                rr, i  = self.getrr(msgdata,i)
                self.adlist.append(rr)
        else:
            self.header.aa = (ord(msgdata[2]) & 4) >> 2
            self.header.tc = (ord(msgdata[2]) & 2) >> 1
            self.header.rd = ord(msgdata[2]) & 1
            self.header.ra = ord(msgdata[3]) >> 7
            self.header.z = (ord(msgdata[3]) & 112) >> 4
            self.header.rcode = ord(msgdata[3]) & 15
            self.header.qdcount = asctoint(msgdata[4:6])
            self.header.ancount = asctoint(msgdata[6:8])
            self.header.nscount = asctoint(msgdata[8:10])
            self.header.arcount = asctoint(msgdata[10:12])
            i = 12
            for x in range(self.header.qdcount):
                log(4,'unpacking question')
                (dn, i) = self.getdomainname(msgdata,i)
                self.question.qname = dn
                rrtype = asctoint(msgdata[i:i+2])
                self.question.qtype = self.qtypes.get(rrtype,chr(rrtype))
                klass = asctoint(msgdata[i+2:i+4])
                self.question.qclass = self.qclasses.get(klass,chr(klass))
                i = i + 4
            for x in range(self.header.ancount):
                log(4,'unpacking answer RR')
                rr, i = self.getrr(msgdata,i)
                self.answerlist.append(rr)
            for x in range(self.header.nscount):
                log(4,'unpacking auth RR')
                rr, i = self.getrr(msgdata,i)            
                self.authlist.append(rr)
            for x in range(self.header.arcount):
                log(4,'unpacking additional RR')
                rr, i = self.getrr(msgdata,i)            
                self.addlist.append(rr)
        return

    def pds(self, s, l):
        # pad string with chr(0)'s so that
        # return string length is l
        x = l - len(s)
        return x*chr(0) + s

    def locsize(self, s):
        x1 = ord(s) >> 4
        x2 = ord(s) & 15
        return (x1, x2)

    def packlocsize(self, x):
        return chr((x[0] << 4) + x[1])

    def packdomainname(self, name, i, msgcomp):
        log(4,'packing domainname: ' + name)
        if name == '':
            return chr(0)
        if name in msgcomp.keys():
            log(4,'using pointer for: ' + name)
            return msgcomp[name]
        packedname = ''
        tokens = name.split('.')
        for j in range(len(tokens)):
            packedname = packedname + chr(len(tokens[j])) + tokens[j]
            nameleft = '.'.join(tokens[j+1:])
            if nameleft in msgcomp.keys():
                log(4,'using pointer for: ' + nameleft)
                return packedname+msgcomp[nameleft]
        # haven't used a pointer so put this in the dictionary
        pointer = inttoasc(i)
        if len(pointer) == 1:
            msgcomp[name] = chr(192)+pointer
        else:
            msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1]
        log(4,'added pointer for ' + name + '(' + str(i) + ')')
        return packedname + chr(0)

    def packrr(self, rr, i, msgcomp):
        rrname = rr.keys()[0]
        rrtype = rr[rrname].keys()[0]
        if self.rqtypes.has_key(rrtype):
            typeval = self.rqtypes[rrtype]
        else:
            typeval = ord(rrtype)
        dbrec = rr[rrname][rrtype][0]
        ttl = dbrec['ttl']
        rclass = self.rqclasses[dbrec['class']]
        packedrr = (self.packdomainname(rrname, i, msgcomp) +
                    self.pds(inttoasc(typeval),2) +
                    self.pds(inttoasc(rclass),2) +
                    self.pds(inttoasc(ttl),4))
        i = i + len(packedrr) + 2
        if rrtype == 'A':
            rdata = socket.inet_aton(dbrec['address'])
        elif rrtype == 'AAAA':
            rdata = ipv6net_aton(dbrec['address'])
        elif rrtype == 'CNAME':
            rdata = self.packdomainname(dbrec['cname'], i, msgcomp)
        elif rrtype == 'HINFO':
            rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] +
                     chr(len(dbrec['os'])) + dbrec['os'])
        elif rrtype == 'LOC':
            rdata = (chr(dbrec['version']) +
                     self.packlocsize(dbrec['size']) +
                     self.packlocsize(dbrec['horiz_pre']) +
                     self.packlocsize(dbrec['vert_pre']) +
                     self.pds(inttoasc(dbrec['latitude']),4) +
                     self.pds(inttoasc(dbrec['longitude']),4) +
                     self.pds(inttoasc(dbrec['altitude']),4))
        elif rrtype == 'MX':
            rdata = (self.pds(inttoasc(dbrec['preference']),2) +
                     self.packdomainname(dbrec['exchange'], i+2, msgcomp))
        elif rrtype == 'NS':
            rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp)
        elif rrtype == 'PTR':
            rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp)
        elif rrtype == 'RP':
            rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
            i = i + len(rdata1)
            rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
            rdata = rdata1 + rdata2
        elif rrtype == 'SOA':
            rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp)
            i = i + len(rdata1)
            rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp)
            rdata = (rdata1 +
                     rdata2 +
                     self.pds(inttoasc(dbrec['serial']),4) +
                     self.pds(inttoasc(dbrec['refresh']),4) +
                     self.pds(inttoasc(dbrec['retry']),4) +
                     self.pds(inttoasc(dbrec['expire']),4) +
                     self.pds(inttoasc(dbrec['minimum']),4))
        elif rrtype == 'SRV':
            rdata = (self.pds(inttoasc(dbrec['priority']),2) +
                     self.pds(inttoasc(dbrec['weight']),2) +
                     self.pds(inttoasc(dbrec['port']),2) +
                     self.packdomainname(dbrec['target'], i+6, msgcomp))
        elif rrtype == 'TXT':
            rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata']
        else:
            rdata = dbrec['rdata']

        return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata

    def buildpkt(self):
        # keep dictionary of names packed (so we can use pointers)
        msgcomp = {}
        # header
        if self.header.id > 65535:
            log(0,'building packet with bad ID field')
            self.header.id = 1
        msgdata = inttoasc(self.header.id)
        if len(msgdata) == 1:
            msgdata = chr(0) + msgdata
        h1 = ((self.header.qr << 7) +
              (self.header.opcode << 3) +
              (self.header.aa << 2) +
              (self.header.tc << 1) +
              (self.header.rd))
        h2 = ((self.header.ra << 7) +
              (self.header.z << 4) +
              (self.header.rcode))
        msgdata = msgdata + chr(h1) + chr(h2)
        msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2)
        msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2)
        msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2)
        msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2)
        # question
        msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp)
        if self.rqtypes.has_key(self.question.qtype):
            typeval = self.rqtypes[self.question.qtype]
        else:
            typeval = ord(self.question.qtype)
        msgdata = msgdata + self.pds(inttoasc(typeval),2)
        if self.rqclasses.has_key(self.question.qclass):
            classval = self.rqclasses[self.question.qclass]
        else:
            classval = ord(self.question.qclass)
        msgdata = msgdata + self.pds(inttoasc(classval),2)
        # rr's
        # RR record format:
        # {'name' : {'type' : [rdata, rdata, ...]}
        # example: {'test.blah.net': {'A': [{'address': '10.1.1.2',
        #                                    'ttl': 3600L}]}}
        for rr in self.answerlist:
            log(4,'packing answer RR')
            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
        for rr in self.authlist:
            log(4,'packing auth RR')
            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
        for rr in self.addlist:
            log(4,'packing additional RR')
            msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
            
        return msgdata

    def printpkt(self):
        print 'ID: ' +str(self.header.id)
        if self.header.qr:
            print 'QR: RESPONSE'
        else:
            print 'QR: QUERY'
        if self.header.opcode == 0:
            print 'OPCODE: STANDARD QUERY'
        elif self.header.opcode == 1:
            print 'OPCODE: INVERSE QUERY'
        elif self.header.opcode == 2:
            print 'OPCODE: SERVER STATUS REQUEST'
        elif self.header.opcode == 5:
            print 'UPDATE REQUEST'
        else:
            print 'OPCODE: UNKNOWN QUERY TYPE'
        if self.header.opcode != 5:
            if self.header.aa:
                print 'AA: AUTHORITATIVE ANSWER'
            else:
                print 'AA: NON-AUTHORITATIVE ANSWER'
            if self.header.tc:
                print 'TC: MESSAGE IS TRUNCATED'
            else:
                print 'TC: MESSAGE IS NOT TRUNCATED'
            if self.header.rd:
                print 'RD: RECURSION DESIRED'
            else:
                print 'RD: RECURSION NOT DESIRED'
            if self.header.ra:
                print 'RA: RECURSION AVAILABLE'
            else:
                print 'RA: RECURSION IS NOT AVAILABLE'
        if self.header.rcode == 1:
            printrcode =  'FORMERR'
        elif self.header.rcode == 2:
            printrcode =  'SERVFAIL'
        elif self.header.rcode == 3:
            printrcode =  'NXDOMAIN'
        elif self.header.rcode == 4:
            printrcode =  'NOTIMP'
        elif self.header.rcode == 5:
            printrcode =  'REFUSED'
        elif self.header.rcode == 6:
            printrcode = 'YXDOMAIN'
        elif self.header.rcode == 7:
            printrcode = 'YXRRSET'
        elif self.header.rcode == 8:
            printrcode = 'NXRRSET'
        elif self.header.rcode == 9:
            printrcode = 'NOTAUTH'
        elif self.header.rcode == 10:
            printrcode = 'NOTZONE'
        else:
            printrcode =  'NOERROR'
        print 'RCODE: ' + printrcode
        if self.header.opcode == 5:
            print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount)
            print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount)
            print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount)
            print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount)
            print 'ZONE SECTION:'
            print 'zname: ' + self.zone.zname
            print 'zonetype: ' + self.zone.ztype
            print 'zoneclass: ' + self.zone.zclass
            print 'PREREQUISITE RRs:'
            for rr in self.prlist:
                print rr
            print 'UPDATE RRs:'        
            for rr in self.uplist:
                print rr
            print 'ADDITIONAL RRs:'        
            for rr in self.addlist:
                print rr


        else:
            print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount)
            print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount)
            print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount)
            print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount)
            print 'QUESTION SECTION:'
            print 'qname: ' + self.question.qname
            print 'querytype: ' + self.question.qtype
            print 'queryclass: ' + self.question.qclass
            print 'ANSWER RRs:'
            for rr in self.answerlist:
                print rr
            print 'AUTHORITY RRs:'        
            for rr in self.authlist:
                print rr
            print 'ADDITIONAL RRs:'        
            for rr in self.addlist:
                print rr

class zonedb:
    def __init__(self, zdict):
        self.zdict = zdict
        self.updates = {}
        for k in self.zdict.keys():
            if self.zdict[k]['type'] == 'slave':
                self.zdict[k]['lastupdatetime'] = 0

    def error(self, id, qname, querytype, queryclass, rcode):
        """Build a bare response that only echoes the question plus an error code.

        id is copied into the header so the client can match the reply;
        rcode is the numeric DNS response code (e.g. 1 FORMERR, 5 REFUSED).
        """
        reply = message()
        reply.header.qr = 1
        reply.header.id = id
        reply.header.rcode = rcode
        q = reply.question
        q.qname, q.qtype, q.qclass = qname, querytype, queryclass
        return reply

    def getorigin(self, zkey):
        origin = ''
        if self.zdict.has_key(zkey):
            origin = self.zdict[zkey]['origin']
        return origin

    def getmasterip(self, zkey):
        masterip = ''
        if self.zdict.has_key(zkey):
            if self.zdict[zkey].has_key('masterip'):
                masterip = self.zdict[zkey]['masterip']
        return masterip

    def zonetrans(self, query):
        """Build the response messages for an AXFR or IXFR query.

        Every returned message carries exactly one RR in its answer section.
        AXFR: the zone SOA, every other RR of the zone, then the SOA again.
        IXFR: when the client's serial (taken from the SOA in the query's
        authority section) is older than ours, the RRs recorded in
        self.updates for serials clientserial..curserial wrapped in SOAs
        (only the SOA if nothing is recorded); otherwise just the SOA.
        Returns [] when no zone has the query name as its origin, or when an
        IXFR query carries no usable SOA for the zone.
        """
        origin = query.question.qname
        querytype = query.question.qtype
        zkey = ''
        for zonekey in self.zdict.keys():
            if self.zdict[zonekey]['origin'] == origin:
                zkey = zonekey
        if not zkey:
            return []
        zonedata = self.zdict[zkey]['zonedata']
        soarec = zonedata[origin]['SOA'][0]
        soa = {origin:{'SOA':[soarec]}}
        if querytype == 'IXFR':
            clientserial = self._ixfrserial(query, origin)
            if clientserial is None:
                # malformed IXFR: previously an IndexError/KeyError here
                log(2,'IXFR query without SOA for ' + origin)
                return []
            rrlist = self._ixfrlist(zkey, clientserial, soarec['serial'], soa)
        else:
            rrlist = [soa]
            for nodename in zonedata.keys():
                for rrtype in zonedata[nodename].keys():
                    if not (rrtype == 'SOA' and nodename == origin):
                        for rr in zonedata[nodename][rrtype]:
                            rrlist.append({nodename:{rrtype:[rr]}})
            rrlist.append(soa)
        queryid = query.header.id
        return [self._xfrmessage(queryid, origin, querytype, rr)
                for rr in rrlist]

    def _ixfrserial(self, query, origin):
        """Return the serial from the origin SOA in query's authority section, or None."""
        for rr in query.authlist:
            if origin in rr and 'SOA' in rr[origin]:
                return rr[origin]['SOA'][0].get('serial')
        return None

    def _ixfrlist(self, zkey, clientserial, curserial, soa):
        """Collect the incremental RR list for an IXFR (see zonetrans)."""
        if clientserial >= curserial:
            return [soa]
        # NOTE(review): added and removed RRs are sent without the
        # intermediate SOAs that RFC 1995 uses to tell them apart.
        # a zone that has never been updated has no entry in self.updates
        history = self.updates.get(zkey, {})
        rrlist = []
        for serial in range(clientserial, curserial+1):
            if serial in history:
                rrlist.extend(history[serial]['added'])
                rrlist.extend(history[serial]['removed'])
        if rrlist:
            rrlist.insert(0, soa)
        rrlist.append(soa)
        return rrlist

    def _xfrmessage(self, queryid, origin, querytype, rr):
        """Wrap a single RR in an authoritative zone-transfer response."""
        msg = message()
        msg.header.id = queryid
        msg.header.qr = 1
        msg.header.aa = 1
        msg.header.rd = 0
        msg.header.qdcount = 1
        msg.question.qname = origin
        msg.question.qtype = querytype
        msg.question.qclass = 'IN'
        msg.header.ancount = 1
        msg.answerlist.append(rr)
        return msg

    def update_zone(self, rrlist, params):
        """Replace a zone's data with the RRs received in a zone transfer.

        rrlist -- RRs in transfer order, each {name: {type: [rdata]}}.  The
                  last element (expected to be the SOA that closes the
                  transfer) is popped off the caller's list and discarded.
        params -- sequence whose first item is the key of the zone to replace.

        Sets the zone's 'lastupdatetime' and rewrites its file.  A failed
        write is logged, not raised.
        """
        zonekey = params[0]
        # drop the trailing duplicate SOA
        rrlist.pop()
        zonedata = {}
        for rr in rrlist:
            rrname = rr.keys()[0]
            rrtype = rr[rrname].keys()[0]
            node = zonedata.setdefault(rrname, {})
            node.setdefault(rrtype, []).append(rr[rrname][rrtype][0])
        self.zdict[zonekey]['zonedata'] = zonedata
        curtime = time.time()
        self.zdict[zonekey]['lastupdatetime'] = curtime
        try:
            f = file(self.zdict[zonekey]['filename'],'w')
            try:
                writezonefile(zonedata, self.zdict[zonekey]['origin'], f)
            finally:
                # close even if writing fails, so the handle is not leaked
                f.close()
        except Exception:
            log(0,'unable to write zone ' + zonekey + ' to disk')
        log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')')

    def remove_zone(self, zonekey):
        if self.zdict.has_key(zonekey):
            del self.zdict[zonekey]

    def getslaves(self, curtime):
        rlist = []
        for k in self.zdict.keys():
            if self.zdict[k]['type'] == 'slave':
                origin = self.zdict[k]['origin']
                refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh']
                if self.zdict[k]['lastupdatetime'] + refresh < curtime:
                    rlist.append((k, origin, self.zdict[k]['masterip']))
        return rlist

    def zmatch(self, qname, zkeys):
        for zkey in zkeys:
            if self.zdict.has_key(zkey):
                origin = self.zdict[zkey]['origin']
                if qname.rfind(origin) != -1:
                    return zkey
        return ''

    def getzlist(self, name, zone):
        if name == zone:
            return
        zlist = []
        i = name.rfind(zone)
        if i == -1:
            return
        firstpart = name[:i-1]
        partlist = firstpart.split('.')
        partlist.reverse()
        lastpart = zone
        for x in range(len(partlist)):
            lastpart = partlist[x] + '.' + lastpart
            zlist.append(lastpart)
        return zlist

    def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc):
        """Answer query from the zones in zkeys and hand the result to cbfunc.

        cbfunc is called exactly once as
        cbfunc(query, addr, server, dorecursion, flist, msglist), where
        msglist is the list of response messages: one message for a normal
        query, one per RR for AXFR/IXFR, and [] when no zone in zkeys covers
        the name (or, for a transfer, has it as origin).  addr, server,
        dorecursion and flist are passed through; dorecursion also sets RA.
        """
        qname = query.question.qname
        querytype = query.question.qtype
        # handle zone transfers separately
        if querytype in ['AXFR','IXFR']:
            answerlist = []
            for zkey in self.zdict.keys():
                if zkey in zkeys and qname == self.zdict[zkey]['origin']:
                    answerlist = self.zonetrans(query)
                    break
            cbfunc(query, addr, server, dorecursion, flist, answerlist)
            return
        zonekey = self.zmatch(qname, zkeys)
        if not zonekey:
            cbfunc(query, addr, server, dorecursion, flist, [])
            return
        answer = self._answerfromzone(zonekey, query, dorecursion)
        cbfunc(query, addr, server, dorecursion, flist, [answer])

    def _answerfromzone(self, zonekey, query, dorecursion):
        """Build the authoritative response to query from zone zonekey's data."""
        qname = query.question.qname
        querytype = query.question.qtype
        origin = self.zdict[zonekey]['origin']
        zonedict = self.zdict[zonekey]['zonedata']
        answer = message()
        answer.header.aa = 1
        answer.header.id = query.header.id
        answer.header.qr = 1
        answer.header.opcode = query.header.opcode
        answer.header.ra = dorecursion
        # NOTIMP unless one of the cases below settles on another code
        answer.header.rcode = 4
        answer.question.qname = qname
        answer.question.qtype = querytype
        answer.question.qclass = query.question.qclass
        rranswerlist = []
        rrnslist = []
        rraddlist = []
        # hosts under the dynamic suffix may come from the VM database
        dbfound, dbrrs = self._dbanswer(qname, querytype)
        rranswerlist.extend(dbrrs)
        if qname in zonedict:
            complete = 1
            if querytype != 'CNAME':
                qname, complete = self._chasecname(zonedict, qname,
                                                   rranswerlist)
            # the name exists, so reply NOERROR even if no RR matches the type
            answer.header.rcode = 0
            if complete:
                answernode = zonedict[qname]
                if querytype == 'ANY':
                    for rrtype in answernode.keys():
                        for rec in answernode[rrtype]:
                            rranswerlist.append({qname:{rrtype:[rec]}})
                elif querytype in answernode:
                    for rec in answernode[querytype]:
                        rranswerlist.append({qname:{querytype:[rec]}})
                    # do rrset ordering (cyclic)
                    if len(answernode[querytype]) > 1:
                        answernode[querytype].append(answernode[querytype].pop(0))
                else:
                    # no data of this type: drop CNAME/database RRs as well
                    rranswerlist = []
            # else: the CNAME chain left the zone or looped; answer with the
            # chain collected so far and let the client continue from there
        elif self._addreferral(zonedict, qname, origin, rrnslist, rraddlist):
            # wildcards are deliberately not supported; this is a delegation
            answer.header.rcode = 0
        elif dbfound:
            # name known only to the database
            answer.header.rcode = 0
            if rranswerlist:
                for rec in zonedict[origin].get('NS', []):
                    rrnslist.append({origin:{'NS':[rec]}})
            else:
                # name exists but has no RR of the asked type: NODATA
                rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
        else:
            # NOTE: RFC1034 section 4.3.4 says we should add the SOA record
            #       to the additional section of the response.  BIND adds
            #       it to the ns section though
            answer.header.rcode = 3
            rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
        answer.header.ancount = len(rranswerlist)
        answer.header.nscount = len(rrnslist)
        answer.header.arcount = len(rraddlist)
        answer.answerlist = rranswerlist
        answer.authlist = rrnslist
        answer.addlist = rraddlist
        return answer

    def _dbanswer(self, qname, querytype):
        """Look qname up in the VM database (sipb_xen_database.NIC).

        Only names ending in '.servers.csail.mit.edu' are looked up, using
        the part before that suffix as the hostname.  Returns (found, rrs):
        found is 1 when the database knows the host; rrs holds its A record
        (ttl 10) only when querytype is 'A' or 'ANY'.
        """
        s = '.servers.csail.mit.edu'
        if not qname.endswith(s):
            return 0, []
        value = sipb_xen_database.NIC.get_by(hostname=qname[:-len(s)])
        if value is None:
            return 0, []
        if querytype not in ('A', 'ANY'):
            return 1, []
        return 1, [{qname: {'A': [{'address': value.ip,
                                   'class': 'IN',
                                   'ttl': 10}]}}]

    def _chasecname(self, zonedict, qname, rrlist):
        """Follow in-zone CNAMEs starting at qname, appending each CNAME RR to rrlist.

        Returns (name, complete): name is where the chase stopped; complete
        is 0 when the chain leaves the zone or loops back on itself.
        """
        seen = {}
        while 'CNAME' in zonedict[qname]:
            if qname in seen:
                return qname, 0
            seen[qname] = 1
            cnamerec = zonedict[qname]['CNAME'][0]
            rrlist.append({qname:{'CNAME':[cnamerec]}})
            qname = cnamerec['cname']
            if qname not in zonedict:
                return qname, 0
        return qname, 1

    def _addreferral(self, zonedict, qname, origin, rrnslist, rraddlist):
        """Add delegation NS RRs (and A glue) for qname to rrnslist/rraddlist.

        Checks every name between origin and qname (from getzlist) for NS
        records.  Returns 1 if any were found, else 0.
        """
        referral = 0
        # `or []` guards against getzlist returning None
        for zonename in self.getzlist(qname, origin) or []:
            if zonename not in zonedict or 'NS' not in zonedict[zonename]:
                continue
            referral = 1
            for rec in zonedict[zonename]['NS']:
                rrnslist.append({zonename:{'NS':[rec]}})
                nsdname = rec['nsdname']
                # add glue records if they exist
                if nsdname in zonedict and 'A' in zonedict[nsdname]:
                    for gluerec in zonedict[nsdname]['A']:
                        rraddlist.append({nsdname:{'A':[gluerec]}})
        return referral

    def handle_update(self, msg, addr, ns):
        """Apply an RFC 2136 dynamic UPDATE message to a zone we are master for.

        msg  -- parsed UPDATE message; its zone section (msg.zone), the
                prerequisite section (msg.prlist) and the update section
                (msg.uplist) are processed.
        addr -- (ip, port) of the requester, passed to
                ns.config.allowupdate() for the permission check.
        ns   -- nameserver object whose config decides update permission.

        Returns (response, origin, slaves).  response comes from
        self.error() with rcode 0 on success, otherwise NOTAUTH(9),
        REFUSED(5), FORMERR(1), NOTZONE(10), NXDOMAIN(3), NXRRSET(8),
        YXDOMAIN(6) or YXRRSET(7).  origin is '' and slaves is [] when no
        master zone matches msg.zone.zname.

        When the update section is non-empty the zone's SOA serial is
        bumped.  RRs removed are recorded under the old serial and RRs
        added under the new serial in self.updates[zkey].  That history is
        reset when the serial wraps or an explicit SOA replacement is applied.
        """
        zkey = ''
        origin = ''
        slaves = []

        def fail(rcode, label):
            # log the reason and build the error reply; origin/slaves are
            # read at call time, so they reflect what has been found so far
            log(2, label)
            errormsg = self.error(msg.header.id, msg.zone.zname,
                                  msg.zone.ztype, msg.zone.zclass, rcode)
            return errormsg, origin, slaves

        def unpack(rr):
            # each rr is {name: {type: [rdata]}} holding a single record
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            return rrname, rrtype, rr[rrname][rrtype][0]

        def inzone(name):
            # label-aligned suffix test: 'www.example.net' is in
            # 'example.net', but 'www.example.net.evil.com' and
            # 'www.badexample.net' are not ('' is the root zone)
            zname = msg.zone.zname
            return (not zname or name == zname or
                    name.endswith('.' + zname))

        def samerdata(zonerec, delrec):
            # class NONE deletions carry class NONE and ttl 0, so only the
            # remaining fields are compared with the zone record.
            # NOTE(review): assumes the parser stores rdata fields of update
            # RRs the same way as zone records -- confirm against message()
            a = zonerec.copy()
            b = delrec.copy()
            for key in ('class', 'ttl'):
                a.pop(key, None)
                b.pop(key, None)
            return a == b

        for zonekey in self.zdict.keys():
            if (self.zdict[zonekey]['type'] == 'master' and
                self.zdict[zonekey]['origin'] == msg.zone.zname):
                zkey = zonekey
        if not zkey:
            return fail(9, 'SENDING NOTAUTH UPDATE ERROR')
        # find the slaves for the zone
        if 'slaves' in self.zdict[zkey]:
            slaves = self.zdict[zkey]['slaves']
        origin = self.zdict[zkey]['origin']
        zd = self.zdict[zkey]['zonedata']
        # check the permissions
        if not ns.config.allowupdate(msg, addr[0], addr[1]):
            return fail(5, 'SENDING REFUSED UPDATE ERROR')

        # prerequisite section (RFC 2136 section 3.2)
        temprrset = {}
        for rr in msg.prlist:
            rrname, rrtype, dbrec = unpack(rr)
            if dbrec['ttl'] != 0:
                return fail(1, 'FORMERROR(1)')
            if not inzone(rrname):
                return fail(10, 'NOTZONE(10)')
            rrclass = dbrec['class']
            if rrclass == 'ANY':
                # "name is in use" / "RRset exists (value independent)"
                if dbrec['rdata']:
                    return fail(1, 'FORMERROR(1)')
                if rrtype == 'ANY':
                    if rrname not in zd:
                        return fail(3, 'NXDOMAIN(3)')
                elif rrname not in zd or rrtype not in zd[rrname]:
                    return fail(8, 'NXRRSET(8)')
            elif rrclass == 'NONE':
                # "name is not in use" / "RRset does not exist"
                if dbrec['rdata']:
                    return fail(1, 'FORMERROR(1)')
                if rrtype == 'ANY':
                    if rrname in zd:
                        return fail(6, 'YXDOMAIN(6)')
                elif rrname in zd and rrtype in zd[rrname]:
                    return fail(7, 'YXRRSET(7)')
            elif rrclass == msg.zone.zclass:
                # "RRset exists (value dependent)": collected, compared below
                temprrset.setdefault(rrname, {}).setdefault(rrtype, []).append(dbrec)
            else:
                return fail(1, 'FORMERROR(1)')
        for nodename in temprrset.keys():
            if (nodename not in zd or
                not self.rrmatch(temprrset[nodename], zd[nodename])):
                return fail(8, 'NXRRSET(8)')

        # update section prescan (RFC 2136 section 3.4.1)
        for rr in msg.uplist:
            rrname, rrtype, dbrec = unpack(rr)
            if not inzone(rrname):
                return fail(10, 'NOTZONE(10)')
            rrclass = dbrec['class']
            if rrclass == msg.zone.zclass:
                formerr = rrtype in ('ANY', 'MAILA', 'MAILB', 'AXFR')
            elif rrclass == 'ANY':
                formerr = (dbrec['ttl'] != 0 or dbrec['rdata'] or
                           rrtype in ('MAILA', 'MAILB', 'AXFR'))
            elif rrclass == 'NONE':
                formerr = (dbrec['ttl'] != 0 or
                           rrtype in ('ANY', 'MAILA', 'MAILB', 'AXFR'))
            else:
                formerr = 1
            if formerr:
                return fail(1, 'FORMERROR(1)')

        # now handle actual update (RFC 2136 section 3.4.2)
        curserial = zd[msg.zone.zname]['SOA'][0]['serial']
        newserial = curserial
        clearupdatehist = 0
        if msg.uplist:
            history = self.updates.setdefault(zkey, {})
            # the history may have been reset by an earlier update, so the
            # current serial is not guaranteed to have an entry yet
            history.setdefault(curserial, {'removed':[], 'added':[]})
            # SOA serials are 32-bit unsigned (RFC 1982): wrap to 2 instead
            # of exceeding 2**32 - 1 and drop the now-meaningless history
            if curserial >= 2**32 - 1:
                newserial = 2
                clearupdatehist = 1
            else:
                newserial = curserial + 1
            history[newserial] = {'removed':[], 'added':[]}
            zd[msg.zone.zname]['SOA'][0]['serial'] = newserial

        def removed(rrname, rrtype, rdata):
            # removed records are filed under the serial they were last in
            self.updates[zkey][curserial]['removed'].append({rrname:{rrtype:[rdata]}})

        def added(rrname, rrtype, rdata):
            # added records are filed under the new serial
            self.updates[zkey][newserial]['added'].append({rrname:{rrtype:[rdata]}})

        for rr in msg.uplist:
            rrname, rrtype, dbrec = unpack(rr)
            rrclass = dbrec['class']
            apex = (rrname == msg.zone.zname)
            if rrclass == msg.zone.zclass:
                if rrtype == 'SOA':
                    # only a newer SOA replaces the existing one
                    if rrname in zd and 'SOA' in zd[rrname]:
                        if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']:
                            del zd[rrname]['SOA'][0]
                            zd[rrname]['SOA'].append(dbrec)
                            clearupdatehist = 1
                elif rrtype == 'WKS':
                    # an existing WKS record is replaced, not added to
                    if rrname in zd and 'WKS' in zd[rrname]:
                        removed(rrname, 'WKS', zd[rrname]['WKS'][0])
                        del zd[rrname]['WKS'][0]
                        zd[rrname]['WKS'].append(dbrec)
                        added(rrname, 'WKS', dbrec)
                else:
                    zd.setdefault(rrname, {}).setdefault(rrtype, []).append(dbrec)
                    added(rrname, rrtype, dbrec)
            elif rrclass == 'ANY':
                # delete one RRset (rrtype) or all RRsets (ANY) at rrname;
                # the apex SOA and NS RRsets are never deleted this way
                if rrname not in zd:
                    continue
                if rrtype == 'ANY':
                    for dnstype in list(zd[rrname].keys()):
                        if apex and dnstype in ('SOA', 'NS'):
                            continue
                        for rdata in zd[rrname][dnstype]:
                            removed(rrname, dnstype, rdata)
                        del zd[rrname][dnstype]
                    if not apex:
                        del zd[rrname]
                elif rrtype in zd[rrname] and not (apex and rrtype in ('SOA', 'NS')):
                    for rdata in zd[rrname][rrtype]:
                        removed(rrname, rrtype, rdata)
                    del zd[rrname][rrtype]
            elif rrclass == 'NONE':
                # delete individual RRs from an RRset (never apex SOA/NS)
                if apex and rrtype in ('SOA', 'NS'):
                    continue
                if rrname not in zd or rrtype not in zd[rrname]:
                    continue
                remaining = []
                for rdata in zd[rrname][rrtype]:
                    if samerdata(rdata, dbrec):
                        removed(rrname, rrtype, rdata)
                    else:
                        remaining.append(rdata)
                if remaining:
                    zd[rrname][rrtype][:] = remaining
                else:
                    del zd[rrname][rrtype]
        if clearupdatehist:
            self.updates[zkey] = {}
        log(2,'SENDING UPDATE NOERROR MSG')
        noerrormsg = self.error(msg.header.id, msg.zone.zname,
                                msg.zone.ztype, msg.zone.zclass, 0)
        return noerrormsg, origin, slaves

class dnscache:
    """Resolver cache of records learned from other nameservers.

    self.cachedb maps a node name to {rrtype: [rdata dict, ...]}.  Every
    rdata dict has a 'ttl' key.  A value of 0 means the entry never
    expires (everything in the zone handed to __init__ and the localhost
    entries).  Any other value is an absolute expiry time based on
    time.time().  NS rdata also carries an 'rtt' used by getnslist() and,
    once flagged by badns(), a 'badtill' time.  An rdata dict whose only
    key is 'ttl' is a negative entry created by addneg().
    """

    def __init__(self,cachezone):
        """cachezone -- initial cache contents (presumably the root hints);
        it must contain a '' (root) node and is used in place, not copied."""
        self.cachedb = cachezone
        # go through and set all of the root ttls to zero (never expire)
        for node in self.cachedb.keys():
            for rtype in self.cachedb[node].keys():
                for rr in self.cachedb[node][rtype]:
                    rr['ttl'] = 0
                    if rtype == 'NS':
                        rr['rtt'] = 0
        # add special entries for localhost
        self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]}
        self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]}
        self.cachedb['']['SOA'] = []
        self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb',
                                        'rname':'cachedb@localhost','serial':1,'refresh':10800,
                                        'retry':3600,'expire':604800,'minimum':3600})

    def hasrdata(self, irrdata, rrdatalist):
        """Return 1 if rrdatalist holds a record equal to irrdata when the
        'ttl' keys are ignored, else 0."""
        wanted = irrdata.copy()
        del wanted['ttl']
        for rrdata in rrdatalist:
            candidate = rrdata.copy()
            del candidate['ttl']
            if candidate == wanted:
                return 1
        return 0

    def add(self, rr, qzone, nsdname):
        """Cache the single record in rr ({name: {type: [rdata]}}).

        qzone is the zone the answering server was asked about.  Records
        whose name is neither qzone nor below it are refused as possible
        cache poisoning ('' accepts anything).  The rdata dict is changed
        in place: its ttl is raised to at least 3600 seconds and then
        turned into an absolute expiry time.  nsdname is not used.
        """
        name = list(rr.keys())[0]
        lname = name.lower()
        lzone = qzone.lower()
        # NOTE: can't cache records from sites that don't own those records
        # (i.e. example.com can't give us A records for www.example.net).
        # The test is label-aligned, so badexample.net is not under example.net.
        if lzone and lname != lzone and not lname.endswith('.' + lzone):
            log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone)
            return
        rtype = list(rr[name].keys())[0]
        rdata = rr[name][rtype][0]
        if rdata['ttl'] < 3600:
            log(2,'low ttl: ' + str(rdata['ttl']))
            rdata['ttl'] = 3600
        rdata['ttl'] = int(time.time() + rdata['ttl'])
        if rtype == 'NS':
            rdata['rtt'] = 0
        name = lname
        rtype = rtype.upper()
        if name in self.cachedb:
            if rtype in self.cachedb[name]:
                if not self.hasrdata(rdata, self.cachedb[name][rtype]):
                    self.cachedb[name][rtype].append(rdata)
                    log(3,'appended rdata to ' +
                        name + '(' + rtype + ') in cache')
                else:
                    log(3,'same rdata for ' + name + '(' +
                        rtype + ') is already in cache')
            else:
                self.cachedb[name][rtype] = [rdata]
                log(3,'appended ' + rtype + ' and rdata to node ' +
                    name + ' in cache')
        else:
            self.cachedb[name] = {rtype:[rdata]}
            log(3,'added node ' + name + '(' + rtype + ') to cache')
        self.reap()

    def addneg(self, qname, querytype, queryclass):
        """Cache for one hour the fact that qname has no querytype records.

        Nothing changes if qname already has a querytype entry.
        queryclass is not used."""
        node = self.cachedb.setdefault(qname, {})
        if querytype not in node:
            node[querytype] = [{'ttl':time.time()+3600}]

    def haskey(self, qname, querytype, msg=''):
        """Look up qname/querytype in the cache, following cached CNAMEs.

        Returns None when nothing usable is cached (including a CNAME that
        points nowhere or a CNAME loop).  Otherwise, when msg is given, it
        returns a response message() answering msg; when msg is not given
        it returns 1.  Negative entries are never part of an answer.
        """
        log(3,'looking for ' + qname + '(' + querytype + ') in cache')
        if qname not in self.cachedb:
            log(3,'Cache has no node for ' + qname)
            return None
        rranswerlist = []
        if querytype != 'CNAME':
            seen = {}
            while 'CNAME' in self.cachedb[qname]:
                cnamerec = self.cachedb[qname]['CNAME'][0]
                if 'cname' not in cnamerec:
                    # negative CNAME entry: no alias here, answer from this node
                    break
                if qname in seen:
                    # CNAME loop in the cached data
                    return None
                seen[qname] = 1
                log(3,'Adding CNAME to cache answer')
                rranswerlist.append({qname:{'CNAME':[cnamerec]}})
                qname = cnamerec['cname']
                if qname not in self.cachedb:
                    # shouldn't have a CNAME that points to nothing
                    return None
        node = self.cachedb[qname]
        if querytype == 'ANY':
            for rrtype in node.keys():
                for rec in node[rrtype]:
                    # can't append negative entries
                    if len(rec.keys()) > 1:
                        rranswerlist.append({qname:{rrtype:[rec]}})
        elif querytype in node:
            for rec in node[querytype]:
                if len(rec.keys()) > 1:
                    rranswerlist.append({qname:{querytype:[rec]}})
        if not rranswerlist:
            return None
        if not msg:
            return 1
        answer = message()
        answer.header.id = msg.header.id
        answer.header.qr = 1
        answer.header.opcode = msg.header.opcode
        answer.header.ra = 1
        answer.question.qname = msg.question.qname
        answer.question.qtype = msg.question.qtype
        answer.question.qclass = msg.question.qclass
        answer.header.rcode = 0
        answer.header.ancount = len(rranswerlist)
        answer.answerlist = rranswerlist
        return answer

    def _nsusable(self, nsrec, curtime):
        # drop an expired 'badtill' mark; return 0 while the server is
        # still flagged bad, 1 otherwise
        if 'badtill' in nsrec:
            if nsrec['badtill'] < curtime:
                del nsrec['badtill']
            else:
                return 0
        return 1

    def _addnsaddrs(self, nsdict, nsrec):
        # file every cached (non-negative) A address of nsrec's server
        # under its rtt; servers with equal rtt overwrite each other
        addrnode = self.cachedb.get(nsrec['nsdname'], {})
        for arec in addrnode.get('A', []):
            if 'address' in arec:
                nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
                                        'ip':arec['address']}

    def getnslist(self, qname):
        """Find the closest cached nameservers to ask about qname.

        Walks from qname towards the root and stops at the first enclosing
        domain whose NS RRset yields a usable server with a cached A
        record.  If none is found, it falls back to the root ('') NS RRset.
        Returns (domainname, nsdict), where nsdict maps rtt ->
        {'name': nsdname, 'ip': address}.
        """
        nsdict = {}
        curtime = time.time()
        tokens = qname.split('.')
        for i in range(len(tokens)):
            domainname = '.'.join(tokens[i:])
            node = self.cachedb.get(domainname)
            if not node or 'NS' not in node:
                continue
            for nsrec in node['NS']:
                if not self._nsusable(nsrec, curtime):
                    log(2,'BAD SERVER, not using ' + nsrec['nsdname'])
                    continue
                self._addnsaddrs(nsdict, nsrec)
            if nsdict:
                return (domainname, nsdict)
        # nothing in the cache matches so give back the root servers
        for nsrec in self.cachedb['']['NS']:
            if self._nsusable(nsrec, curtime):
                self._addnsaddrs(nsdict, nsrec)
        return ('', nsdict)

    def badns(self, zonename, nsdname):
        """Flag nsdname as a bad server for zonename for the next hour."""
        for nsrec in self.cachedb.get(zonename, {}).get('NS', []):
            if nsrec['nsdname'] == nsdname:
                log(2,'Setting ' + nsdname + ' as bad nameserver')
                nsrec['badtill'] = time.time() + 3600

    def updatertt(self, qname, zone, rtt):
        """Store rtt on zone's NS record(s) naming the server qname."""
        for rr in self.cachedb.get(zone, {}).get('NS', []):
            if rr['nsdname'] == qname:
                log(2,'updating rtt for ' + qname + ' to ' + str(rtt))
                rr['rtt'] = rtt

    def reap(self):
        """Drop expired records, then any RRsets and nodes left empty.

        Records whose ttl is 0 never expire."""
        now = time.time()
        for nodename in list(self.cachedb.keys()):
            node = self.cachedb[nodename]
            for rrtype in list(node.keys()):
                # filter instead of removing while iterating, which skipped
                # the element following each removed one
                live = [rdata for rdata in node[rrtype]
                        if rdata['ttl'] == 0 or rdata['ttl'] >= now]
                if live:
                    node[rrtype][:] = live
                else:
                    del node[rrtype]
            if not node:
                del self.cachedb[nodename]
        return

    def zonetrans(self, queryid):
        """Dump the cache as an AXFR-style list of messages.

        Each message carries one RR, and the root SOA is both the first
        and the last message.  queryid becomes every message's header id.
        """
        zonedata = self.cachedb
        soa = {'':{'SOA':[zonedata['']['SOA'][0]]}}
        rrlist = [soa]
        for nodename in zonedata.keys():
            for rrtype in zonedata[nodename].keys():
                if not (rrtype == 'SOA' and nodename == ''):
                    for rr in zonedata[nodename][rrtype]:
                        rrlist.append({nodename:{rrtype:[rr]}})
        rrlist.append(soa)
        msglist = []
        for rr in rrlist:
            msg = message()
            msg.header.id = queryid
            msg.header.qr = 1
            msg.header.aa = 1
            msg.header.rd = 0
            msg.header.qdcount = 1
            msg.question.qname = 'cache'
            msg.question.qtype = 'AXFR'
            msg.question.qclass = 'IN'
            msg.header.ancount = 1
            msg.answerlist.append(rr)
            msglist.append(msg)
        return msglist

class gethostaddr(asyncore.dispatcher):
    """Asynchronously resolve hostname to a single IPv4 address.

    An A query for hostname is sent over UDP to port 53 of serveraddr when
    the object is created.  When the reply arrives, cbfunc(address) is
    called with the first matching A record's address, or cbfunc('') when
    the reply cannot be parsed or holds no address.  If the reply holds
    only a CNAME, a new query for the alias is started, at most maxcnames
    times.  No timeout is applied: without a reply cbfunc is never called.
    """
    def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1', maxcnames=8):
        asyncore.dispatcher.__init__(self)
        self.cbfunc = cbfunc
        # kept for the CNAME follow-up query in handle_read
        self.serveraddr = serveraddr
        # how many further CNAME hops may still be followed
        self.maxcnames = maxcnames
        self.msg = message()
        self.msg.question.qname = hostname
        self.msg.question.qtype = 'A'
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))

    def handle_read(self):
        replydata, addr = self.socket.recvfrom(1500)
        self.close()
        try:
            replymsg = message(replydata)
        except Exception:
            log(0,'unable to process packet')
            # report the failure instead of leaving the caller waiting
            self.cbfunc('')
            return
        answername = replymsg.question.qname
        cname = ''
        # go through twice to catch cnames after A recs
        for rr in replymsg.answerlist:
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'CNAME':
                answername = dbrec['cname']
                cname = answername
        for rr in replymsg.answerlist:
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'A':
                self.cbfunc(dbrec['address'])
                return
        # if we got a cname and no A, query for the cname.  This socket is
        # already closed, so the follow-up runs on a fresh lookup object.
        if cname and self.maxcnames > 0:
            gethostaddr(cname, self.cbfunc, self.serveraddr, self.maxcnames - 1)
        else:
            self.cbfunc('')

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class simpleudprequest(asyncore.dispatcher):
    """Send one DNS query over UDP and hand the parsed reply to a callback.

    msg is sent to port 53 of serveraddr when the object is created.  On
    the first datagram received the socket is closed and
    cbfunc(replymsg, outqkey) is called.  An unparseable reply is logged
    and dropped without calling cbfunc.  No timeout is applied here.
    """
    def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''):
        asyncore.dispatcher.__init__(self)
        self.gotanswer = 0
        self.msg = msg
        self.cbfunc = cbfunc
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        # opaque key handed back to cbfunc with the reply
        self.outqkey = outqkey
        self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))

    def handle_read(self):
        replydata, addr = self.socket.recvfrom(1500)
        self.close()
        try:
            replymsg = message(replydata)
        except Exception:
            # a malformed reply must not take down the event loop, but
            # KeyboardInterrupt/SystemExit are no longer swallowed
            log(0,'unable to process packet')
            return
        self.cbfunc(replymsg, self.outqkey)

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class simpletcprequest(asyncore.dispatcher):
    """Send one DNS query over TCP to port 53 of serveraddr.

    Messages are framed with a 2-byte length prefix in both directions.
    For an AXFR query, the answer RRs of successive replies are collected
    until the second SOA, and then cbfunc(rrlist) or
    cbfunc(rrlist, cbparams) is called.  For any other query, the first
    reply is passed to cbfunc the same way.  errorfunc, if given, is
    called with cbparams[0] when an AXFR yields no records, or with the
    query name when the connection closes before the request completes.
    """
    def __init__(self, msg, cbfunc, cbparams=None, serveraddr='127.0.0.1', errorfunc=''):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.query = msg
        self.cbfunc = cbfunc
        if cbparams is None:
            cbparams = []
        self.cbparams = cbparams
        self.errorfunc = errorfunc
        msgdata = msg.buildpkt()
        ml = inttoasc(len(msgdata))
        if len(ml) == 1:
            ml = chr(0) + ml
        self.buffer = ml+msgdata
        # bytes received but not yet assembled into a whole message
        self.rbuffer = ''
        self.rmsgleft = 0
        self.rmsglength = 0
        self.rrlist = []
        # set once the callback (or AXFR error callback) has run
        self.done = 0
        log(2,'sending tcp request to ' + serveraddr)
        self.connect((serveraddr,53))

    def recv (self, buffer_size):
        try:
            data = self.socket.recv (buffer_size)
            if not data:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.handle_close()
                return ''
            else:
                return data
        except socket.error:
            why = sys.exc_info()[1]
            # winsock sometimes throws ENOTCONN
            if why.args[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]:
                self.handle_close()
                return ''
            # re-raise with the original traceback
            raise

    def handle_connect(self):
        pass

    def _finish(self, result):
        # close the connection, then deliver result to the callback
        self.done = 1
        self.close()
        if self.cbparams:
            self.cbfunc(result, self.cbparams)
        else:
            self.cbfunc(result)

    def handle_msg(self, msg):
        if self.query.question.qtype == 'AXFR':
            if not self.rrlist and not msg.answerlist:
                # the first reply carries no records: the transfer failed.
                # NOTE(review): assumes cbparams is non-empty whenever
                # errorfunc is set
                self.done = 1
                if self.errorfunc:
                    self.errorfunc(self.cbparams[0])
                self.close()
                return
            # a reply may carry many RRs; the second SOA ends the transfer
            for rr in msg.answerlist:
                rrname = list(rr.keys())[0]
                rrtype = list(rr[rrname].keys())[0]
                self.rrlist.append(rr)
                if rrtype == 'SOA' and len(self.rrlist) > 1:
                    self._finish(self.rrlist)
                    return
        else:
            self._finish(msg)

    def handle_read(self):
        self.rbuffer = self.rbuffer + self.recv(8192)
        # a length prefix and its message may be split across reads, and
        # one read may hold several messages
        while not self.done and len(self.rbuffer) >= 2:
            self.rmsglength = asctoint(self.rbuffer[:2])
            if len(self.rbuffer) < 2 + self.rmsglength:
                return  # wait for the rest of this message
            msgdata = self.rbuffer[2:2 + self.rmsglength]
            self.rbuffer = self.rbuffer[2 + self.rmsglength:]
            try:
                reply = message(msgdata)
            except Exception:
                log(0,'unable to process packet')
                return
            try:
                self.handle_msg(reply)
            except Exception:
                log(0,'error handling tcp reply: ' + str(sys.exc_info()[1]))
                return

    def writable(self):
        return (len(self.buffer) > 0)

    def handle_write(self):
        sent = self.send(self.buffer)
        self.buffer = self.buffer[sent:]

    def handle_close(self):
        # the peer went away (or the connection failed) before completion
        if self.errorfunc and not self.done:
            self.errorfunc(self.query.question.qname)
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class udpdnsserver(asyncore.dispatcher):
    """UDP listener on port: each datagram goes to
    dnsserver.handle_packet(data, addr, self), and replies come back
    through sendpackets()."""
    def __init__(self, port, dnsserver):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self.bind(('',port))
        self.dnsserver = dnsserver
        # replies longer than this are sent truncated (TC bit, no records)
        self.maxmsgsize = 500

    def handle_read(self):
        # drain all queued datagrams; the loop ends when recvfrom() fails
        # with EWOULDBLOCK, i.e. nothing is left to read
        try:
            while 1:
                msgdata, addr = self.socket.recvfrom(1500)
                self.dnsserver.handle_packet(msgdata, addr, self)
        except socket.error:
            if sys.exc_info()[1].args[0] != EWOULDBLOCK:
                # re-raise with the original traceback
                raise

    def sendpackets(self, msglist, addr):
        for msg in msglist:
            msgdata = msg.buildpkt()
            if len(msgdata) > self.maxmsgsize:
                msg.header.tc = 1
                # take off all the answers to ensure
                # the packet size is small enough
                msg.header.ancount = 0
                msg.header.nscount = 0
                msg.header.arcount = 0
                msg.answerlist = []
                msg.authlist = []
                msg.addlist = []
                msgdata = msg.buildpkt()
            self.sendto(msgdata, addr)

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        # the listening socket stays open
        return

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class tcpdnschannel(asynchat.async_chat):
    """Serve DNS over one accepted TCP connection.

    Incoming bytes are split into messages by their 2-byte length prefix.
    Each message is passed to server.dnsserver.handle_packet(data, addr,
    self), and replies come back through sendpackets(), which closes the
    connection once the replies have been sent.
    """
    def __init__(self, server, s, addr):
        asynchat.async_chat.__init__(self, s)
        self.server = server
        self.addr = addr
        # no terminator: every received chunk goes to collect_incoming_data
        self.set_terminator(None)
        self.databuffer = ''
        self.msglength = 0
        log(3,'Created new tcp channel')

    def collect_incoming_data(self, data):
        # buffer everything so that a length prefix or message split
        # across reads (or several messages in one read) is handled
        self.databuffer = self.databuffer + data
        while len(self.databuffer) >= 2:
            self.msglength = asctoint(self.databuffer[:2])
            if len(self.databuffer) < 2 + self.msglength:
                return  # wait for the rest of this message
            msgdata = self.databuffer[2:2 + self.msglength]
            self.databuffer = self.databuffer[2 + self.msglength:]
            self.msglength = 0
            # got entire message
            self.server.dnsserver.handle_packet(msgdata, self.addr, self)

    def sendpackets(self, msglist, addr):
        for msg in msglist:
            x = msg.buildpkt()
            ml = inttoasc(len(x))
            if len(ml) == 1:
                ml = chr(0) + ml
            self.push(ml+x)
        # close only after the pushed data has been written; close() could
        # drop output still queued in the producer fifo
        self.close_when_done()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class tcpdnsserver(asyncore.dispatcher):
    """Listening TCP socket on port; each accepted connection is served by
    a tcpdnschannel that hands requests to dnsserver."""
    def __init__(self, port, dnsserver):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(('',port))
        self.listen(5)
        self.dnsserver = dnsserver

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() returns None when no connection was
            # actually pending; unpacking it would raise TypeError
            return
        conn, addr = pair
        tcpdnschannel(self, conn, addr)

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class nameserver:
    """Front end shared by the UDP and TCP listeners.

    Dispatches incoming DNS messages (queries, dynamic UPDATEs, NOTIFYs) to
    the zone database and, when the client's view allows recursion, to the
    recursive resolver.  reap() also drives slave-zone refreshes and
    retries of outgoing NOTIFYs.
    """
    def __init__(self, resolver, localconfig):
        # recursive resolver, or a false value when none is running
        self.resolver = resolver
        self.config = localconfig
        self.zdb = self.config.zonedatabase
        self.last_reap_time = time.time()
        self.maint_int = 10            # seconds between slave refresh checks
        self.slavesupdating = []       # zone keys with a zone transfer in flight
        self.notifys = []              # (origin, ipaddr, trynum) NOTIFYs to send
        self.sentnotify = []           # (origin, ipaddr, trynum, senttime) awaiting a reply
        self.notify_retry_time = 30    # seconds before an unanswered NOTIFY is resent
        self.notify_retries = 4
        self.askedsoa = {}             # zkey -> state of the pending SOA query
        self.soatimeout = 10           # seconds to wait for an SOA reply

    def error(self, id, qname, querytype, queryclass, rcode):
        """Build an error response echoing the question, with the given rcode."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def need_zonetransfer(self, zkey, origin, masterip, trynum=0):
        """Ask masterip for the SOA of origin; the reply goes to handle_soaquery()."""
        self.askedsoa[zkey] = {'masterip':masterip,
                               'senttime':time.time(),
                               'origin':origin,
                               'trynum':trynum+1}
        query = message()
        query.header.id = random.randrange(1,32768)
        query.header.rd = 0
        query.question.qname = origin
        query.question.qtype = 'SOA'
        query.question.qclass = 'IN'
        log(3,'slave checking for new data in ' + origin)
        simpleudprequest(query, self.handle_soaquery,
                         masterip, zkey)

    def handle_soaquery(self, msg, zkey):
        """Handle the master's SOA reply by starting an AXFR of the zone.

        Replies for zones no longer being waited on (for example, zones that
        reap() already gave up on) are ignored.
        """
        if zkey not in self.askedsoa:
            return
        # NOTE(review): the SOA serial in msg is not compared with the local
        # copy, so every successful SOA check triggers a full transfer.
        origin = msg.question.qname
        masterip = self.askedsoa[zkey]['masterip']
        del self.askedsoa[zkey]
        if zkey not in self.slavesupdating:
            self.slavesupdating.append(zkey)
            query = message()
            query.header.id = random.randrange(1,32768)
            query.header.rd = 0
            query.question.qname = origin
            query.question.qtype = 'AXFR'
            query.question.qclass = 'IN'
            log(3,'Updating slave zone: ' + zkey)
            simpletcprequest(query, self.handle_zonetrans,
                             [zkey],masterip,self.handle_zterror)

    def handle_zonetrans(self, rrlist, params):
        """Install a completed zone transfer; params[0] is the zone key."""
        log(1,'handling zone transfer')
        zonekey = params[0]
        self.zdb.update_zone(rrlist, params)
        if zonekey in self.slavesupdating:
            self.slavesupdating.remove(zonekey)

    def handle_zterror(self, zonekey):
        """A zone transfer failed: forget it and remove the zone."""
        if zonekey in self.slavesupdating:
            self.slavesupdating.remove(zonekey)
        self.zdb.remove_zone(zonekey)

    def rrmatch(self, rrset1, rrset2):
        """Return 1 if every rrtype in rrset1 is in rrset2 with the same count.

        Returns None otherwise.
        """
        for rrtype in rrset1.keys():
            if rrtype not in rrset2.keys():
                return
            else:
                if len(rrset1[rrtype]) != len(rrset2[rrtype]):
                    return
        return 1

    def process_notify(self, msg, ipaddr, port):
        """On a NOTIFY for a slave zone in the sender's view, re-check its master."""
        (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port)
        goodzkey = ''
        goodorigin = ''
        goodmasterip = ''
        for zkey in zkeys:
            origin = self.zdb.getorigin(zkey)
            if origin == msg.question.qname:
                masterip = self.zdb.getmasterip(zkey)
                if masterip:
                    # remember the matching zone's own origin and master;
                    # later loop iterations must not overwrite them
                    goodzkey = zkey
                    goodorigin = origin
                    goodmasterip = masterip
        if goodzkey:
            log(3,'got NOTIFY from ' + goodmasterip)
            self.need_zonetransfer(goodzkey, goodorigin, goodmasterip, 0)
        return

    def notify(self):
        """Send queued NOTIFYs and re-queue those unanswered for notify_retry_time."""
        curtime = time.time()
        pending = []
        for entry in self.sentnotify:
            (origin, ipaddr, trynum, senttime) = entry
            if curtime >= senttime + self.notify_retry_time:
                # no reply within the retry interval: send it again
                self.notifys.append((origin, ipaddr, trynum))
            else:
                pending.append(entry)
        # rebuild the list instead of removing entries while iterating it
        self.sentnotify = pending
        for origin, ipaddr, trynum in self.notifys:
            msg = message()
            msg.question.qname = origin
            msg.question.qtype = 'SOA'
            msg.question.qclass = 'IN'
            msg.header.opcode = 4
            # there probably is a better way to do this
            if self.resolver:
                self.resolver.send_to([msg],(ipaddr,53))
                if trynum+1 <= self.notify_retries:
                    self.sentnotify.append((origin,ipaddr,trynum+1,curtime))
        self.notifys = []

    def handle_packet(self, msgdata, addr, server):
        """Dispatch one incoming DNS message.

        msgdata is the raw packet and addr the (ip, port) of the sender.
        server is the listener that received the packet; replies go out
        through server.sendpackets(msglist, addr).  Unparseable packets are
        dropped.
        """
        # self.reap()
        try:
            msg = message(msgdata)
        except Exception:
            # malformed packet: drop it
            return
        # find a matching view
        (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1])
        if not msg.header.qr and msg.header.opcode == 5:
            log(2,'GOT UPDATE PACKET')
            # check the zone section
            if (msg.header.zocount != 1 or
                msg.zone.ztype != 'SOA' or
                msg.zone.zclass != 'IN'):
                log(2,'SENDING FORMERR UPDATE ERROR')
                errormsg = self.error(msg.header.id, msg.zone.zname,
                                  msg.zone.ztype, msg.zone.zclass, 1)
                server.sendpackets([errormsg],addr)
            else:
                (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self)
                if answer.header.rcode == 0:
                    # schedule NOTIFYs to slaves
                    for ipaddr in slaves:
                        self.notifys.append((origin, ipaddr, 0))
                server.sendpackets([answer],addr)
        elif msg.header.opcode == 4:
            if msg.header.qr:
                log(0,'got NOTIFY response')
                # iterate over a copy: matching entries are removed as we go
                for entry in self.sentnotify[:]:
                    (origin, ipaddr, trynum, senttime) = entry
                    if ipaddr == addr[0] and msg.question.qname == origin:
                        self.sentnotify.remove(entry)
            else:
                log(0,'got NOTIFY')
                self.process_notify(msg, addr[0], addr[1])
        elif not msg.header.qr and msg.header.opcode == 0:
            # it's a question
            qname = msg.question.qname.lower()
            log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype +
                ') from ' + addr[0])
            # handle special version packet
            if (msg.question.qtype == 'TXT' and
                msg.question.qclass == 'CH'):
                if qname == 'version.bind':
                    server.sendpackets([getversion(qname,
                                                  msg.header.id,
                                                  msg.header.rd,
                                                  dorecursion, '1.0')],addr)
                elif qname == 'version.oak':
                    server.sendpackets([getversion(qname,
                                                  msg.header.id,
                                                  msg.header.rd,
                                                  dorecursion, '1.0')],addr)
                return
            self.zdb.lookup(zkeys, msg, addr, server, dorecursion,
                               flist, self.lookup_callback)

    def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist):
        """Reply with local answers, or hand the query to the resolver."""
        if answerlist:
            server.sendpackets(self.config.outpackets(answerlist), addr)
        elif dorecursion:
            if msg.question.qtype in ['AXFR','IXFR']:
                if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR':
                    if self.resolver:
                        server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr)
                else:
                    # won't forward zone transfers and
                    # don't handle recursive zone transfers
                    server.sendpackets([self.error(msg.header.id, msg.question.qname,
                                                   msg.question.qtype,
                                                   msg.question.qclass,2)],addr)
            elif self.resolver:
                self.resolver.handle_query(msg, addr, flist, server.sendpackets)
            else:
                # recursion allowed by the view but no resolver is running
                server.sendpackets([self.error(msg.header.id, msg.question.qname,
                                               msg.question.qtype,
                                               msg.question.qclass,2)],addr)

    def reap(self):
        """Periodic maintenance, called from the main loop.

        Runs resolver maintenance and outgoing NOTIFYs on every call.  Every
        maint_int seconds it asks masters for the SOA of the slave zones that
        zdb.getslaves() reports.  SOA queries that timed out are retried, and
        the zone is removed after more than 3 attempts.
        """
        log(4,'in nameserver reap')
        # do all maintenence (interval) stuff here
        if self.resolver:
            self.resolver.reap()
        self.notify()
        curtime = time.time()
        if curtime > (self.last_reap_time + self.maint_int):
            self.last_reap_time = curtime
            # do zone transfers here if slave server and haven't asked for soa
            for (zkey, origin, masterip) in self.zdb.getslaves(curtime):
                if zkey not in self.askedsoa:
                    self.need_zonetransfer(zkey, origin, masterip)
        # iterate over a snapshot: entries are deleted and re-added below
        for zkey in list(self.askedsoa.keys()):
            if curtime > self.askedsoa[zkey]['senttime'] + self.soatimeout:
                if self.askedsoa[zkey]['trynum'] > 3:
                    self.zdb.remove_zone(zkey)
                    del self.askedsoa[zkey]
                else:
                    masterip = self.askedsoa[zkey]['masterip']
                    origin = self.askedsoa[zkey]['origin']
                    trynum = self.askedsoa[zkey]['trynum']
                    del self.askedsoa[zkey]
                    self.need_zonetransfer(zkey, origin, masterip, trynum)

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))

class resolver(asyncore.dispatcher):
    """Asynchronous recursive/forwarding DNS resolver on a UDP socket.

    Cache misses are tracked in self.outq, keyed by getoutqkey() plus an id.
    Each record holds:
      - the original query and the client address;
      - the reply callback, invoked as cbfunc(msglist, addr);
      - the iteration state.
    Records whose 'addr' is 'IQ' are internal glue lookups with no client
    waiting on them.  Questions already in flight are parked in self.holdq
    and answered from the cache later by process_holdq().
    """
    def __init__(self, cache, port=0):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self.bind(('',port))
        self.cache = cache
        self.outqnum = 0
        self.outq = {}            # outqkey -> in-flight query record
        self.holdq = {}           # qname+qtype -> list of parked query records
        self.holdtime = 10        # seconds a parked query waits before the cache is re-checked
        self.holdqlength = 100    # max parked queries per qname+qtype
        self.last_reap_time = time.time()
        self.maint_int = 10       # seconds between outq expiry sweeps
        self.timeout = 3          # seconds before an outstanding query expires

    def getoutqkey(self):
        """Return the next outq key prefix (a counter that wraps at 99999)."""
        self.outqnum = self.outqnum + 1
        if self.outqnum == 99999:
            self.outqnum = 1
        return str(self.outqnum)

    def error(self, id, qname, querytype, queryclass, rcode):
        """Build an error response echoing the question, with the given rcode."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def qpacket(self, id, qname, querytype, queryclass):
        # create a question
        query = message()
        query.header.id = id
        query.header.rd = 0
        query.question.qname = qname
        query.question.qtype = querytype
        query.question.qclass = queryclass
        return query

    def send_to(self, msglist, addr):
        """Send messages over UDP, truncating (TC bit) any over 512 bytes."""
        for msg in msglist:
            data = msg.buildpkt()
            if len(data) > 512:
                # packet to big
                msg.header.tc = 1
                msg.header.ancount = 0
                msg.answerlist = []
                msg.header.nscount = 0
                msg.authlist = []
                msg.header.arcount = 0
                msg.addlist = []
                self.socket.sendto(msg.buildpkt(), addr)
            else:
                self.socket.sendto(data, addr)

    def handle_read(self):
        """Drain every pending datagram from the socket."""
        try:
            while 1:
                msgdata, addr = self.socket.recvfrom(1500)
                # should put 'try' here in production server
                self.handle_packet(msgdata, addr)
        except socket.error, why:
            if why[0] != asyncore.EWOULDBLOCK:
                # re-raise with the original traceback intact
                raise

    def handle_packet(self, msgdata, addr):
        try:
            msg = message(msgdata)
        except Exception:
            # malformed packet: drop it
            return
        if not msg.header.qr:
            self.handle_query(msg, addr, [], self.send_to)
        else:
            log(2,'received unsolicited reply')


    def handle_query(self, msg, addr, flist, cbfunc):
        """Answer msg from the cache or start resolving it.

        flist is a list of forwarder IPs (empty for full recursion).  cbfunc
        is called as cbfunc(msglist, addr) with the reply.
        """
        qname = msg.question.qname
        querytype = msg.question.qtype
        queryclass = msg.question.qclass
        # check the cache first
        answer = self.cache.haskey(qname,querytype,msg)
        if answer:
            cbfunc([answer], addr)
            log(2,'sent answer for ' + qname + '(' + querytype +
                ') from cache')
        else:
            # check if query is already in progess
            for oqkey in self.outq.keys():
                if (self.outq[oqkey]['qname'] == qname and
                    self.outq[oqkey]['querytype'] == querytype):
                    log(2,'query already in progress for '+qname+'('+querytype+')')
                    # put entry in hold queue to try later
                    hqrec = {'processtime':time.time()+self.holdtime,
                             'query':msg,'addr':addr,
                             'qname':qname,'querytype':querytype,
                             'queryclass':queryclass,
                             'cbfunc':cbfunc}
                    self.putonhold(hqrec)
                    return

            outqkey = self.getoutqkey()+str(msg.header.id)
            self.outq[outqkey] = {'query':msg,
                                  'addr':addr,
                                  'qname':qname,
                                  'querytype':querytype,
                                  'queryclass':queryclass,
                                  'cbfunc':cbfunc,
                                  'answerlist':[],
                                  'addlist':[],
                                  'qsent':0}
            if flist:
                # private copy: failed forwarders are popped from it, which
                # must not alter the caller's (configuration's) list
                self.outq[outqkey]['flist'] = list(flist)
                self.askfns(outqkey)
            else:
                self.askns(outqkey)

    def putonhold(self,hqrec):
        """Park hqrec until its processtime, up to holdqlength per question."""
        hqid = hqrec['qname']+hqrec['querytype']
        # create the queue for this question on first use
        queue = self.holdq.setdefault(hqid, [])
        if len(queue) < self.holdqlength:
            hqrec['processtime']=time.time()+self.holdtime
            queue.append(hqrec)

    def _reply(self, outqkey, rcode=0):
        """Return an empty response to the client's original query in outq[outqkey]."""
        query = self.outq[outqkey]['query']
        answer = message()
        answer.header.id = query.header.id
        answer.header.qr = 1
        answer.header.opcode = query.header.opcode
        answer.header.ra = 1
        if rcode:
            answer.header.rcode = rcode
        answer.question.qname = query.question.qname
        answer.question.qtype = query.question.qtype
        answer.question.qclass = query.question.qclass
        return answer

    def _servfail(self, outqkey):
        """Send SERVFAIL to the client of outq[outqkey]; no-op for internal queries."""
        rec = self.outq[outqkey]
        if rec['addr'] != 'IQ':
            query = rec['query']
            # echo the client's question (qname may have moved along a CNAME chain)
            rec['cbfunc']([self.error(query.header.id, query.question.qname,
                                      query.question.qtype, query.question.qclass,
                                      2)],
                          rec['addr'])

    def askns(self, outqkey):
        """Send the query in outq[outqkey] to the best nameserver the cache knows.

        Gives up, sending SERVFAIL to the client, after 10 attempts or when
        no nameservers are known for the name.
        """
        qname = self.outq[outqkey]['qname']
        querytype = self.outq[outqkey]['querytype']
        queryclass = self.outq[outqkey]['queryclass']
        # don't try more than 10 times to avoid loops
        if self.outq[outqkey]['qsent'] == 10:
            self._servfail(outqkey)
            del self.outq[outqkey]
            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
                   ' POSSIBLE LOOP')
            return
        # find the best nameservers to ask from the cache
        (qzone, nsdict) = self.cache.getnslist(qname)
        if not nsdict:
            # there are no good servers
            self._servfail(outqkey)
            del self.outq[outqkey]
            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
                   'no good name servers to ask')
            return
        # pick the best nameserver: nsdict is keyed by rtt, lowest wins
        bestrtt = min(nsdict.keys())
        bestnsip = nsdict[bestrtt]['ip']
        bestnsname = nsdict[bestrtt]['name']
        # fill in the callback data structure
        qid = random.randrange(1,32768)
        self.outq[outqkey]['nsqueriedlastip'] = bestnsip
        self.outq[outqkey]['nsqueriedlastname'] = bestnsname
        self.outq[outqkey]['nsdict'] = nsdict
        self.outq[outqkey]['qzone'] = qzone
        self.outq[outqkey]['qsenttime'] = time.time()
        self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1
        self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(qid,qname,querytype,queryclass),
                                                      self.handle_response, bestnsip, outqkey)
        # update rtt so that we ask a different server next time
        self.cache.updatertt(bestnsname,qzone,1)
        log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname +
               ') for ' + qname + '(' + querytype + ')')

    def askfns(self, outqkey):
        """Send the query in outq[outqkey] to the first remaining forwarder."""
        flist = self.outq[outqkey]['flist']
        qname = self.outq[outqkey]['qname']
        querytype = self.outq[outqkey]['querytype']
        queryclass = self.outq[outqkey]['queryclass']
        self.outq[outqkey]['qsenttime'] = time.time()
        qid = random.randrange(1,32768)
        self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(qid,qname,querytype,queryclass),
                                                         self.handle_fresponse, flist[0], outqkey)
        log(2,''+outqkey+'|sent query to forwarder')

    def handle_response(self, msg, outqkey):
        # either reponse:
        # 1. contains a name error
        # 2. answers the question
        #    (cache data and return it)
        # 3. is (contains) a CNAME and qtype isn't
        #    (cache cname and change qname to it)
        #    (check if qname and qtype are in any other rrs in the response)
        #    (must check cache again here)
        # 4. contains a better delegation
        #    (cache the delegation and start again)
        #    -- or, with no NS records, is a negative (NODATA) answer
        # 5. is aserver failure
        #    (delete server from list and try again)

        # make sure that original question is still outstanding
        if outqkey not in self.outq:
            # the query already finished or was expired by reap()
            log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
            return

        if msg.header.rcode not in [1,2,4,5]:
            # update rtt time
            rtt = time.time() - self.outq[outqkey]['qsenttime']
            nsname = self.outq[outqkey]['nsqueriedlastname']
            zone = self.outq[outqkey]['qzone']
            self.cache.updatertt(nsname,zone,rtt)

        if msg.header.rcode == 3:
            log(2,outqkey+'|GOT Name Error for ' + msg.question.qname +
                '(' + msg.question.qtype + ')')
            # name error
            # cache negative answer
            self.cache.addneg(self.outq[outqkey]['qname'],
                              self.outq[outqkey]['querytype'],
                              self.outq[outqkey]['queryclass'])
            if self.outq[outqkey]['addr'] != 'IQ':
                # tell the client the name does not exist (rcode 3)
                answer = self._reply(outqkey, 3)
                self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
            del self.outq[outqkey]

        elif msg.header.ancount > 0:
            # answer (may be CNAME)
            haveanswer = 0
            cname = ''
            log(2,'CACHING ANSWERLIST ENTRIES')
            for rr in msg.answerlist:
                rrname = rr.keys()[0]
                rrtype = rr[rrname].keys()[0]
                if ((rrname == msg.question.qname or rrname == cname ) and
                    rrtype == msg.question.qtype):
                    haveanswer = 1
                # follow CNAME chains contained in this same response
                if ((rrname == msg.question.qname or rrname == cname) and
                    rrtype == 'CNAME'):
                    cname = rr[rrname][rrtype][0]['cname']
                self.cache.add(rr, self.outq[outqkey]['qzone'],
                               self.outq[outqkey]['nsqueriedlastname'])
            if haveanswer:
                if self.outq[outqkey]['addr'] != 'IQ':
                    log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname +
                        '(' + msg.question.qtype + ')' )
                    answer = self._reply(outqkey)
                    # CNAMEs collected while chasing the chain precede the
                    # final records (RFC 1034 section 4.3.2)
                    answer.answerlist = self.outq[outqkey]['answerlist'] + msg.answerlist
                    answer.header.ancount = len(answer.answerlist)
                    self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
                    log(2,outqkey+'|sent answer retrieved from remote server for ' +
                           self.outq[outqkey]['query'].question.qname)
                else:
                    log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' +
                        msg.question.qtype + ')')
                del self.outq[outqkey]
            elif cname:
                log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')')
                self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist
                self.outq[outqkey]['qname'] = cname
                self.askns(outqkey)
            else:
                log(2,outqkey+'|GOT BOGUS answer for '  + msg.question.qname + '(' +
                    msg.question.qtype + ')')
                self._servfail(outqkey)
                del self.outq[outqkey]

        elif msg.header.nscount > 0 and msg.header.ancount == 0:
            log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')')
            # delegation
            # cache the nameserver rrs and start over
            # if there are no glue records for nameservers must fetch them first
            log(2,'CACHING AUTHLIST ENTRIES')
            for rr in msg.authlist:
                self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
            log(2,'CACHING ADDLIST ENTRIES')
            for rr in msg.addlist:
                self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
            fetchglue = 0
            nscount = 0
            for rr in msg.authlist:
                nodename = rr.keys()[0]
                if rr[nodename].keys()[0] == 'NS':
                    nscount = nscount + 1
                    nsdname = rr[nodename]['NS'][0]['nsdname']
                    if not self.cache.haskey(nsdname,'A'):
                        log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)')
                        fetchglue = fetchglue + 1
                        # need to fetch A rec (internal query, no client)
                        noutqkey = self.getoutqkey()+str(random.randrange(1,32768))
                        self.outq[noutqkey] = {'query':'',
                                               'addr':'IQ',
                                               'qname':nsdname,
                                               'querytype':'A',
                                               'queryclass':'IN',
                                               'answerlist':[],
                                               'addlist':[],
                                               'qsent':0}
                        log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)')
                        self.askns(noutqkey)
            if not nscount:
                if msg.header.rcode == 0 and self.outq[outqkey]['addr'] != 'IQ':
                    # no referral: a NOERROR reply with an empty answer
                    # section is a negative (NODATA) answer; relay its
                    # authority records to the client
                    log(2,outqkey+'|GOT NODATA for ' +
                           msg.question.qname + '(' + msg.question.qtype + ')' )
                    answer = self._reply(outqkey)
                    answer.authlist = msg.authlist
                    answer.header.nscount = len(answer.authlist)
                    self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
                else:
                    log(2,outqkey+'|Dropping query (no ns recs) for ' +
                           msg.question.qname + '(' + msg.question.qtype + ')' )
                del self.outq[outqkey]
            elif fetchglue == nscount:
                log(2,outqkey+'|Stalling query (no glue recs) for ' +
                       msg.question.qname + '(' + msg.question.qtype + ')')
                self.putonhold(self.outq[outqkey])
                del self.outq[outqkey]
            else:
                log(2,outqkey+'|got (some) glue with delegation')
                self.askns(outqkey)

        elif msg.header.rcode in [1,2,4,5]:
            log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode))
            log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' +
                 self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname)
            # don't ask this server for a while
            self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
            self.askns(outqkey)
        else:
            log(2,outqkey+'|GOT UNPARSEABLE REPLY')
            msg.printpkt()

    def handle_fresponse(self, msg, outqkey):
        """Relay a forwarder's reply, or try the next forwarder on failure."""
        if outqkey not in self.outq:
            # the query already finished or was expired by reap()
            log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
            return
        if msg.header.rcode in [1,2,4,5]:
            self.outq[outqkey]['flist'].pop(0)
            if len(self.outq[outqkey]['flist']) == 0:
                # every forwarder failed
                self._servfail(outqkey)
                del self.outq[outqkey]
            else:
                self.askfns(outqkey)
        else:
            # relay the forwarder's answer, keeping its rcode (e.g. NXDOMAIN)
            answer = self._reply(outqkey, msg.header.rcode)
            answer.header.ancount = msg.header.ancount
            answer.header.nscount = msg.header.nscount
            answer.header.arcount = msg.header.arcount
            answer.answerlist = msg.answerlist
            answer.authlist = msg.authlist
            answer.addlist = msg.addlist
            if msg.header.rcode == 3:
                # name error
                # cache negative answer
                self.cache.addneg(self.outq[outqkey]['qname'],
                                  self.outq[outqkey]['querytype'],
                                  self.outq[outqkey]['queryclass'])
            else:
                # cache all rrs
                for rr in msg.answerlist:
                    self.cache.add(rr,'','forwarder')
                for rr in msg.authlist:
                    self.cache.add(rr,'','forwarder')
                for rr in msg.addlist:
                    self.cache.add(rr,'','forwarder')
            self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
            del self.outq[outqkey]

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        # print '1:In handle close'
        return

    def process_holdq(self):
        """Answer parked queries whose hold time has passed, from the cache.

        Due queries are removed whether or not the cache has an answer.
        Internal ('IQ') records have no client, so they are simply dropped.
        """
        curtime = time.time()
        for hqkey in list(self.holdq.keys()):
            # rebuild the queue instead of removing entries while iterating it
            waiting = []
            for hqrec in self.holdq[hqkey]:
                if curtime < hqrec['processtime']:
                    waiting.append(hqrec)
                    continue
                log(2,'processing held query')
                if hqrec['addr'] == 'IQ':
                    continue
                answer = self.cache.haskey(hqrec['qname'],
                                           hqrec['querytype'],
                                           hqrec['query'])
                if answer:
                    hqrec['cbfunc']([answer], hqrec['addr'])
                    log(2,'sent answer for ' + hqrec['qname'] +
                        '(' + hqrec['querytype'] +  ') from cache')
            if waiting:
                self.holdq[hqkey] = waiting
            else:
                del self.holdq[hqkey]

    def reap(self):
        """Periodic maintenance: process the hold queue and expire stale queries."""
        self.process_holdq()
        curtime = time.time()
        log(3,timestamp() + 'processed HOLDQ (sockets: ' +
            str(len(asyncore.socket_map.keys()))+')')
        if curtime > (self.last_reap_time + self.maint_int):
            self.last_reap_time = curtime
            for outqkey in list(self.outq.keys()):
                rec = self.outq.get(outqkey)
                if rec is None:
                    # removed while closing an earlier expired request
                    continue
                if curtime > rec['qsenttime'] + self.timeout:
                    log(2,'query for '+rec['qname']+'('+
                        rec['querytype']+') expired')
                    # don't set forwarders as bad
                    if 'flist' not in rec:
                        self.cache.badns(rec['qzone'],
                                         rec['nsqueriedlastname'])
                    if 'request' in rec:
                        log(3,'closing socket for expired query')
                        rec['request'].close()
                    if outqkey in self.outq:
                        del self.outq[outqkey]
        return

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))


def run(configobj):
    global loglevel
    r = resolver(dnscache(configobj.cached))
    ns = nameserver(r, configobj)
    udpds = udpdnsserver(53, ns)
    tcpds = tcpdnsserver(53, ns)
    loglevel = configobj.loglevel
    try:
        loop(ns.reap)
    except KeyboardInterrupt:
        print 'server done'

if __name__ == '__main__':
    # Development entry point: serve servers.csail.mit.edu as a master zone
    # and seed the resolver cache from db.ca.
    # NOTE(review): the database DSN below is hard-coded for sipb-xen-dev.
    sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')

    # Other zone configurations, kept for reference.  They were previously
    # assigned to variables that were overwritten or never used:
    #   master: {'example.net': {'origin': 'example.net',
    #                            'filename': 'db.example.net',
    #                            'type': 'master',
    #                            'slaves': []}}
    #   slave:  {'example.net': {'origin': 'example.net',
    #                            'filename': 'db.example.net',
    #                            'type': 'slave',
    #                            'masterip': '127.0.0.1'}}
    zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu',
                                         'filename':'db.servers.csail.mit.edu',
                                         'type':'master',
                                         'slaves':[]}}

    readzonefiles(zonedict)
    lconfig = dnsconfig()
    lconfig.zonedatabase = zonedb(zonedict)
    pr = zonefileparser()
    pr.parse('','db.ca')
    lconfig.cached = pr.getzdict()
    lconfig.loglevel = 3

    run(lconfig)
