#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time

class DatabaseAuthority(common.ResolverBase):
    """An authority whose host records are looked up in the invirt database.

    Queries for names under one of ``self.domains`` are answered
    authoritatively: the zone apex gets A/NS/SOA records built from the
    first configured nameserver, and any other name is resolved by looking
    up a ``Machine`` row and returning the IP of its first NIC.
    """

    # Replaced per instance in __init__ with the zone's SOA record.
    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and precompute the static zone records.

        domains  -- iterable of domain names to serve; defaults to
                    ``config.dns.domains`` when None.
        database -- argument forwarded to ``invirt.database.connect``;
                    when None, ``connect()`` is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # Only the first configured nameserver is advertised.
        ns = config.dns.nameservers[0]
        # SOA RNAME encodes the contact mailbox with '.' in place of the
        # first '@'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        # Glue A record for the nameserver, sent in the additional section.
        record = dns.Record_A(address=ns.ip, ttl=3600)
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Resolve a query, retrying when the database connection fails.

        Calls ``_lookup_unsafe`` up to three times.  A
        ``psycopg2.OperationalError`` or ``sqlalchemy.exceptions.SQLError``
        on any attempt but the last is logged and retried after a short
        pause; on the last attempt it is re-raised to the caller.

        Returns whatever ``_lookup_unsafe`` returns (a Deferred).
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                # Forward the caller's timeout instead of discarding it.
                return self._lookup_unsafe(name, cls, type, timeout=timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    raise
                # Single-argument parenthesised form prints the same text
                # whether print is a statement or a function.
                print("Reloading database")
                # NOTE(review): time.sleep blocks the calling thread; if
                # this runs on the reactor thread, other queries stall for
                # the duration -- confirm this is acceptable.
                time.sleep(0.5)

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a single query without any database-error handling.

        name    -- queried domain name (matched case-insensitively).
        cls     -- DNS class; only ``dns.IN`` is answered.
        type    -- DNS record type (``dns.A``, ``dns.NS``, ``dns.SOA``,
                   ``dns.ALL_RECORDS``, ...).
        timeout -- accepted for interface compatibility; not used here.

        Returns a Deferred that either fires with a
        ``(results, authority, additional)`` tuple of RRHeader lists, or
        fails with ``dns.DomainError`` (name is outside our domains) or
        ``dns.AuthoritativeDomainError`` (name is ours but does not exist,
        or the class is not IN).
        """
        # NOTE(review): presumably drops cached ORM state so each query
        # sees current data -- confirm against invirt.database.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()
        if name in self.domains:
            domain = name
        else:
            # The first configured domain that is a suffix of the name
            # wins; 'domain' stays bound to it after the break.
            for domain in self.domains:
                if name.endswith('.'+domain):
                    break
            else: #Not us
                return defer.fail(failure.Failure(dns.DomainError(name)))
        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]
        additional = [self.ns1]
        if cls != dns.IN:
            #Doesn't exist
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        # Strip ".<domain>" from the end; empty when name is the zone apex.
        host = name[:-len(domain)-1]
        if not host:
            if type in (dns.A, dns.ALL_RECORDS):
                # The apex resolves to the first nameserver's address.
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer; don't repeat it.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            value = invirt.database.Machine.get_by(name=host)
            if value is None or not value.nics:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            # Only the machine's first NIC is published.
            ip = value.nics[0].ip
            if ip is None:  #Deactivated?
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        if not results:
            # No answer for this type: send a completely empty response.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))

if __name__ == '__main__':
    # Serve the database-backed authority on port 53 over both UDP and TCP.
    resolver = DatabaseAuthority()

    verbosity = 0
    factory = server.DNSServerFactory(authorities=[resolver],
                                      verbose=verbosity)
    protocol = dns.DNSDatagramProtocol(factory)
    # Apply the same logging verbosity to the factory and the protocol.
    factory.noisy = verbosity
    protocol.noisy = verbosity

    reactor.listenUDP(53, protocol)
    reactor.listenTCP(53, factory)
    reactor.run()
