#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time
import re

class DatabaseAuthority(common.ResolverBase):
    """An authority that answers DNS queries from the Invirt database.

    For a name that ends in one of the served domains, the part of the
    name preceding the domain is used as a machine name: the matching
    Machine row is looked up and the IP of its first NIC is returned.
    Queries for a served domain itself are answered with records built
    from the first configured nameserver (A, NS and SOA).
    """

    soa = None

    # How many times a lookup is attempted before a database error is
    # allowed to propagate, and how long to wait (seconds) between tries.
    lookup_attempts = 3
    retry_delay = 0.5

    def __init__(self, domains=None, database=None):
        """Connect to the database and prebuild the static records.

        domains  -- collection of domain names to serve; defaults to
                    config.dns.domains.
        database -- argument handed to invirt.database.connect(); when
                    None, connect() is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # Only the first configured nameserver is advertised.
        ns = config.dns.nameservers[0]
        # The SOA rname is the contact address with its first '@'
        # replaced by '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, placed in the additional section
        # of answers.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Run _lookup_unsafe, retrying when the database errors out.

        A psycopg2.OperationalError or sqlalchemy SQLError causes the
        lookup to be retried after a short pause; the error from the
        final attempt is re-raised to the caller.  Returns whatever
        _lookup_unsafe returns.
        """
        last_attempt = self.lookup_attempts - 1
        for attempt in range(self.lookup_attempts):
            try:
                # Forward the caller's timeout rather than discarding it.
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == last_attempt:
                    raise
                # Parenthesised form prints the same text on Python 2.
                print("Reloading database")
                # NOTE(review): time.sleep blocks the calling thread; if
                # this runs on the reactor thread, no other query is
                # served during the pause -- confirm this is acceptable.
                time.sleep(self.retry_delay)

    def _find_domain(self, name):
        """Return the served domain that `name` belongs to, or None.

        An exact match wins; otherwise the longest served domain that
        `name` ends with (as '.<domain>') is chosen.
        """
        if name in self.domains:
            return name
        best_domain = ''
        for domain in self.domains:
            if name.endswith('.' + domain) and len(domain) > len(best_domain):
                best_domain = domain
        if best_domain == '':
            return None
        return best_domain

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a single query without any retry handling.

        Returns a Deferred.  On success it fires with a tuple
        (results, authority, additional) of RRHeader lists.  It fails
        with dns.DomainError when `name` is not inside any served
        domain, and with dns.AuthoritativeDomainError when the class is
        not IN, or when no machine with a usable IP matches the name.
        `timeout` is accepted for interface compatibility and not used.
        """
        # NOTE(review): presumably clears cached ORM state so that each
        # query sees current database contents -- confirm.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        domain = self._find_domain(name)
        if domain is None:
            return defer.fail(failure.Failure(dns.DomainError(name)))

        if cls != dns.IN:
            #Doesn't exist
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]
        additional = [self.ns1]

        # Everything before ".<domain>"; empty when name == domain.
        host = name[:-len(domain)-1]
        if not host: # Request for the domain itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is the answer itself, so it is not
                # repeated in the authority section.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else: # Request for a subdomain.
            value = invirt.database.Machine.query().filter_by(name=host).first()
            if value is None or not value.nics:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            ip = value.nics[0].ip
            if ip is None:  #Deactivated?
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        if not results:
            # The name exists but has no record of the requested type:
            # return a completely empty (but successful) answer.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))

class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    The escaping backslash itself is left in the token.
    """
    # Grab everything up to the first whitespace character or
    # quotation mark not preceded by a backslash
    whitespace_re = re.compile(r'(.*?)([\t\n\x0b\x0c\r ]+|(?<!\\)")')

    def collapseContinuations(self, lines):
        """Join parenthesised continuations and split lines into tokens.

        lines -- iterable of zone-file line strings.

        Text from a '(' up to the matching ')' (possibly several lines
        later) is folded onto the line containing the '('.  Whatever
        follows '(' on its own line, and whatever follows ')' on its
        line, is discarded.  Each resulting line is then split on
        whitespace, except that a double-quoted phrase is kept as a
        single token with its internal whitespace intact.

        Returns a list of token lists; lines with no tokens are dropped.
        """
        joined = []
        in_parens = False
        for line in lines:
            if not in_parens:
                open_at = line.find('(')
                if open_at == -1:
                    joined.append(line)
                else:
                    joined.append(line[:open_at])
                    in_parens = True
            else:
                close_at = line.find(')')
                if close_at != -1:
                    joined[-1] += ' ' + line[:close_at]
                    in_parens = False
                else:
                    joined[-1] += ' ' + line

        L = []
        for line in joined:
            in_quote = False
            split_line = []
            while len(line) > 0:
                match = self.whitespace_re.match(line)
                if match is None:
                    # If there's no match, that means that there's no
                    # whitespace in the rest of the line, so it should
                    # be treated as a single entity, quoted or not
                    #
                    # This also means that a closing quote isn't
                    # strictly necessary if the line ends the quote
                    substr = line
                    end = ''
                else:
                    substr, end = match.groups()

                if in_quote:
                    # If we're in the middle of the quote, the string
                    # we just grabbed belongs at the end of the
                    # previous string
                    #
                    # Including the whitespace! Unless it's not
                    # whitespace and is actually a closequote instead
                    split_line[-1] += substr + (end if end != '"' else '')
                elif substr or end == '"':
                    # If we're not in the middle of a quote, then this
                    # is the next new string.  An empty substr is only
                    # kept when a quote opens right here (it seeds the
                    # quoted token); an empty substr followed by
                    # whitespace is just separator text -- e.g. the gap
                    # after a closing quote -- and must not become an
                    # empty token.
                    split_line.append(substr)

                if end == '"':
                    in_quote = not in_quote

                # Then strip off what we just processed
                line = line[len(substr + end):]
            L.append(split_line)
        # Drop lines that produced no tokens; always returns a list.
        return [tokens for tokens in L if tokens]

if '__main__' == __name__:
    # Build one zone-file authority per (zone file, served domain) pair,
    # then add the database-backed authority after them.
    resolvers = []
    for zone in config.dns.zone_files:
        for origin in config.dns.domains:
            r = QuotingBindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand
            r.origin = origin
            # Close the zone file explicitly instead of relying on
            # garbage collection to release the handle.
            zone_file = open(zone)
            try:
                lines = zone_file.readlines()
            finally:
                zone_file.close()
            lines = r.collapseContinuations(r.stripComments(lines))
            r.parseLines(lines)

            resolvers.append(r)
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()
