1 | #!/usr/bin/python |
---|
2 | from twisted.internet import reactor |
---|
3 | from twisted.names import server |
---|
4 | from twisted.names import dns |
---|
5 | from twisted.names import common |
---|
6 | from twisted.names import authority |
---|
7 | from twisted.internet import defer |
---|
8 | from twisted.python import failure |
---|
9 | |
---|
10 | from invirt.common import InvirtConfigError |
---|
11 | from invirt.config import structs as config |
---|
12 | import invirt.database |
---|
13 | import psycopg2 |
---|
14 | import sqlalchemy |
---|
15 | import time |
---|
16 | import re |
---|
17 | |
---|
class DatabaseAuthority(common.ResolverBase):
    """An Authority that answers DNS queries from the invirt database.

    Forward lookups of ``host.<domain>`` are resolved against the NIC and
    Machine tables; reverse (PTR) lookups under ``in-addr.arpa`` map an IP
    address back to the hostname recorded on the NIC holding it.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the SOA/NS records served.

        domains -- domains this authority answers for; defaults to
                   config.dns.domains.
        database -- passed to invirt.database.connect() when given;
                    otherwise connect() is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        ns = config.dns.nameservers[0]
        # The SOA RNAME is the contact address with its first '@' turned into '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the first nameserver, returned in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Answer a query, retrying when the database connection fails.

        Up to three attempts are made. After a psycopg2 OperationalError or
        SQLAlchemy SQLError the lookup sleeps 0.5 s and tries again; the
        error from the third attempt is re-raised.
        """
        for attempt in range(3):
            try:
                value = self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == 2:
                    raise
                print("Reloading database")
                # NOTE(review): time.sleep blocks the Twisted reactor while
                # waiting; consider a non-blocking retry.
                time.sleep(0.5)
            else:
                return value

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query from the database, without retry handling.

        Returns a Deferred that either fires with (answers, authority,
        additional) or fails. It fails with dns.DomainError for names that
        are outside every served domain and not under in-addr.arpa. It fails
        with dns.AuthoritativeDomainError for unknown hosts and for
        non-IN classes. Database errors are not caught here, so _lookup can
        retry them.
        """
        # Presumably discards cached ORM state so each query sees fresh data.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest served domain that name falls under.
            best_domain = ''
            for candidate in self.domains:
                if name.endswith('.' + candidate) and len(candidate) > len(best_domain):
                    best_domain = candidate
            if best_domain == '':
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse resolution requests
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain

        results = []
        additional = [self.ns1]
        # Named authority_rrs so it does not shadow the twisted.names.authority module.
        authority_rrs = [dns.RRHeader(domain, dns.NS, dns.IN,
                                      3600, self.ns, auth=True)]

        if cls != dns.IN:
            # Only the IN class is served; anything else doesn't exist.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "4.3.2.1.in-addr.arpa" -> "1.2.3.4"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                nic = invirt.database.NIC.query.filter_by(ip=ip).first()
                if not (nic and nic.hostname):
                    # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                hostname = nic.hostname
                if '.' not in hostname:
                    # Qualify bare hostnames with the first configured domain.
                    hostname = hostname + "." + config.dns.domains[0]
                record = dns.Record_PTR(hostname, ttl)
                results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        elif name == domain or name == '.' + domain:
            if type in (dns.A, dns.ALL_RECORDS):
                # The bare domain resolves to the first nameserver's address.
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already in the answer section.
                authority_rrs = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            host = name[:-len(domain) - 1]
            # A NIC hostname takes precedence over a machine name.
            nic = invirt.database.NIC.query.filter_by(hostname=host).first()
            if nic:
                ip = nic.ip
            else:
                machine = invirt.database.Machine.query().filter_by(name=host).first()
                if not machine or not machine.nics:
                    # Unknown host, or a machine with no NIC to take an
                    # address from (previously an uncaught IndexError).
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                ip = machine.nics[0].ip

            if ip is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        # FIXME: Should only return success with no records if the name actually exists
        if not results:
            # Empty answer: send no authority or additional records either.
            authority_rrs = []
            additional = []
        return defer.succeed((results, authority_rrs, additional))
---|
143 | |
---|
class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    Parentheses are still matched without regard to quoting, so a '('
    or ')' inside a quoted string is treated as a grouping marker.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line. This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
        re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def collapseContinuations(self, lines):
        """Join parenthesized groups into logical lines and tokenize them.

        Every physical line that starts inside an open '(' group is appended
        to the previous logical line. The parentheses themselves are removed,
        but any text before or after them is kept. This also covers a group
        opened and closed on the same line, such as
        "@ IN SOA ns. host. ( 1 2 3 4 5 )".

        Each logical line is then split into tokens using string_pat.
        Double-quoted phrases become a single token with the quotes
        removed, and backslash escapes are resolved with escape_pat.
        Logical lines that yield no tokens (e.g. blank lines) are dropped.
        """
        logical_lines = []
        in_parens = False
        for line in lines:
            continuing = in_parens
            pieces = []
            rest = line
            # A physical line may open and/or close any number of groups.
            while True:
                if in_parens:
                    pos = rest.find(')')
                else:
                    pos = rest.find('(')
                if pos == -1:
                    pieces.append(rest)
                    break
                pieces.append(rest[:pos])
                rest = rest[pos + 1:]
                in_parens = not in_parens
            text = ' '.join(pieces)
            if continuing:
                logical_lines[-1] += ' ' + text
            else:
                logical_lines.append(text)

        tokenized = []
        for logical in logical_lines:
            tokens = []
            for match in self.string_pat.finditer(logical):
                quoted, bare = match.groups()
                if quoted is None:
                    token = bare
                else:
                    token = quoted
                tokens.append(self.escape_pat.sub(r'\1', token))
            tokenized.append(tokens)
        return filter(None, tokenized)
---|
190 | |
---|
if '__main__' == __name__:
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            # Read each zone file once and close it promptly; the same lines
            # are re-parsed below under every configured origin.
            zone_file = open(zone)
            try:
                zone_lines = zone_file.readlines()
            finally:
                zone_file.close()
            for origin in config.dns.domains:
                r = QuotingBindAuthority(zone)
                # The constructor parses the zone with its own default origin
                # (presumably derived from the file name; see
                # twisted.names.authority). To reuse one generic zone file for
                # each configured domain, set the origin and re-parse by hand.
                r.origin = origin
                lines = r.collapseContinuations(r.stripComments(zone_lines))
                r.parseLines(lines)

                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass
    # The database-backed authority is consulted after any zone-file ones.
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()
---|