source: trunk/packages/invirt-dns/invirt-dns @ 1974

Last change on this file since 1974 was 1974, checked in by broder, 15 years ago

DNS lookups first look in the nics table and then the machines table.

This allows VMs with multiple IPs to have DNS names associated with
both IPs.

  • Property svn:executable set to *
File size: 7.2 KB
Line 
1#!/usr/bin/python
2from twisted.internet import reactor
3from twisted.names import server
4from twisted.names import dns
5from twisted.names import common
6from twisted.names import authority
7from twisted.internet import defer
8from twisted.python import failure
9
10from invirt.config import structs as config
11import invirt.database
12import psycopg2
13import sqlalchemy
14import time
15import re
16
17class DatabaseAuthority(common.ResolverBase):
18    """An Authority that is loaded from a file."""
19
20    soa = None
21
22    def __init__(self, domains=None, database=None):
23        common.ResolverBase.__init__(self)
24        if database is not None:
25            invirt.database.connect(database)
26        else:
27            invirt.database.connect()
28        if domains is not None:
29            self.domains = domains
30        else:
31            self.domains = config.dns.domains
32        ns = config.dns.nameservers[0]
33        self.soa = dns.Record_SOA(mname=ns.hostname,
34                                  rname=config.dns.contact.replace('@','.',1),
35                                  serial=1, refresh=3600, retry=900,
36                                  expire=3600000, minimum=21600, ttl=3600)
37        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
38        record = dns.Record_A(address=ns.ip, ttl=3600)
39        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
40                                3600, record, auth=True)
41
42   
43    def _lookup(self, name, cls, type, timeout = None):
44        for i in range(3):
45            try:
46                value = self._lookup_unsafe(name, cls, type, timeout = None)
47            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
48                if i == 2:
49                    raise
50                print "Reloading database"
51                time.sleep(0.5)
52                continue
53            else:
54                return value
55
56    def _lookup_unsafe(self, name, cls, type, timeout):
57        invirt.database.clear_cache()
58       
59        ttl = 900
60        name = name.lower()
61
62        if name in self.domains:
63            domain = name
64        else:
65            # Look for the longest-matching domain.  (This works because domain
66            # will remain bound after breaking out of the loop.)
67            best_domain = ''
68            for domain in self.domains:
69                if name.endswith('.'+domain) and len(domain) > len(best_domain):
70                    best_domain = domain
71            if best_domain == '':
72                return defer.fail(failure.Failure(dns.DomainError(name)))
73            domain = best_domain
74        results = []
75        authority = []
76        additional = [self.ns1]
77        authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
78                                      3600, self.ns, auth=True))
79
80        if cls == dns.IN:
81            host = name[:-len(domain)-1]
82            if not host: # Request for the domain itself.
83                if type in (dns.A, dns.ALL_RECORDS):
84                    record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
85                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
86                                                ttl, record, auth=True))
87                elif type == dns.NS:
88                    results.append(dns.RRHeader(domain, dns.NS, dns.IN,
89                                                ttl, self.ns, auth=True))
90                    authority = []
91                elif type == dns.SOA:
92                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
93                                                ttl, self.soa, auth=True))
94            else: # Request for a subdomain.
95                value = invirt.database.NIC.query.filter_by(hostname=host).first()
96                if value:
97                    ip = value.ip
98                else:
99                    value = invirt.database.Machine.query().filter_by(name=host).first()
100                    if value:
101                        ip = value.nics[0].ip
102                    else:
103                        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
104               
105                if ip is None:
106                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
107
108                if type in (dns.A, dns.ALL_RECORDS):
109                    record = dns.Record_A(ip, ttl)
110                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
111                                                ttl, record, auth=True))
112                elif type == dns.SOA:
113                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
114                                                ttl, self.soa, auth=True))
115            if len(results) == 0:
116                authority = []
117                additional = []
118            return defer.succeed((results, authority, additional))
119        else:
120            #Doesn't exist
121            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
122
123class QuotingBindAuthority(authority.BindAuthority):
124    """
125    A BindAuthority that (almost) deals with quoting correctly
126   
127    This will catch double quotes as marking the start or end of a
128    quoted phrase, unless the double quote is escaped by a backslash
129    """
130    # Match either a quoted or unquoted string literal followed by
131    # whitespace or the end of line.  This yields two groups, one of
132    # which has a match, and the other of which is None, depending on
133    # whether the string literal was quoted or unquoted; this is what
134    # necessitates the subsequent filtering out of groups that are
135    # None.
136    string_pat = \
137            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
138
139    # For interpreting escapes.
140    escape_pat = re.compile(r'\\(.)')
141
142    def collapseContinuations(self, lines):
143        L = []
144        state = 0
145        for line in lines:
146            if state == 0:
147                if line.find('(') == -1:
148                    L.append(line)
149                else:
150                    L.append(line[:line.find('(')])
151                    state = 1
152            else:
153                if line.find(')') != -1:
154                    L[-1] += ' ' + line[:line.find(')')]
155                    state = 0
156                else:
157                    L[-1] += ' ' + line
158        lines = L
159        L = []
160
161        for line in lines:
162            in_quote = False
163            split_line = []
164            for m in self.string_pat.finditer(line):
165                [x] = [x for x in m.groups() if x is not None]
166                split_line.append(self.escape_pat.sub(r'\1', x))
167            L.append(split_line)
168        return filter(None, L)
169
170if '__main__' == __name__:
171    resolvers = []
172    for zone in config.dns.zone_files:
173        for origin in config.dns.domains:
174            r = QuotingBindAuthority(zone)
175            # This sucks, but if I want a generic zone file, I have to
176            # reload the information by hand
177            r.origin = origin
178            lines = open(zone).readlines()
179            lines = r.collapseContinuations(r.stripComments(lines))
180            r.parseLines(lines)
181           
182            resolvers.append(r)
183    resolvers.append(DatabaseAuthority())
184
185    verbosity = 0
186    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
187    p = dns.DNSDatagramProtocol(f)
188    f.noisy = p.noisy = verbosity
189   
190    reactor.listenUDP(53, p)
191    reactor.listenTCP(53, f)
192    reactor.run()
Note: See TracBrowser for help on using the repository browser.