source: trunk/packages/invirt-vnc-server/python/vnc/extauth.py @ 1402

Last change on this file since 1402 was 1388, checked in by broder, 16 years ago

Generate the VNC token key at invirt-vnc-server install-time instead
of hard-coding

File size: 7.1 KB
Line 
1"""
2Wrapper for Invirt VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
10import sys
11import struct
12import string
13import cPickle
14# Python 2.5:
15#import hashlib
16import sha
17import hmac
18import base64
19import socket
20import time
21
22def getTokenKey():
23    token_key = file('/etc/invirt/secrets/vnc-key').read().strip()
24    while True:
25        yield token_key
26getTokenKey = getTokenKey().next
27
28def getPort(name, auth_data):
29    import get_port
30    if (auth_data["machine"] == name):
31        port = get_port.findPort(name)
32        if port is None:
33            return 0
34        return int(port.split(':')[1])
35    else:
36        return None
37   
38class VNCAuthOutgoing(protocol.Protocol):
39   
40    def __init__(self,socks):
41        self.socks=socks
42
43    def connectionMade(self):
44        peer = self.transport.getPeer()
45        self.socks.makeReply(200)
46        self.socks.otherConn=self
47
48    def connectionLost(self, reason):
49        self.socks.transport.loseConnection()
50
51    def dataReceived(self,data):
52        self.socks.write(data)
53
54    def write(self,data):
55        self.transport.write(data)
56
57
58class VNCAuth(protocol.Protocol):
59   
60    def __init__(self,server="localhost"):
61        self.server=server
62        self.auth=None
63   
64    def connectionMade(self):
65        self.buf=""
66        self.otherConn=None
67
68    def validateToken(self, token):
69        self.auth_error = "Invalid token"
70        try:
71            token = base64.urlsafe_b64decode(token)
72            token = cPickle.loads(token)
73            m = hmac.new(getTokenKey(), digestmod=sha)
74            m.update(token['data'])
75            if (m.digest() == token['digest']):
76                data = cPickle.loads(token['data'])
77                expires = data["expires"]
78                if (time.time() < expires):
79                    self.auth = data["user"]
80                    self.auth_error = None
81                    self.auth_machine = data["machine"]
82                    self.auth_data = data
83                else:
84                    self.auth_error = "Token has expired; please try logging in again"
85        except (TypeError, cPickle.UnpicklingError):
86            self.auth = None           
87            print sys.exc_info()
88
89    def dataReceived(self,data):
90        if self.otherConn:
91            self.otherConn.write(data)
92            return
93        self.buf=self.buf+data
94        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
95            lines = self.buf.splitlines()
96            args = lines.pop(0).split()
97            command = args.pop(0)
98            headers = {}
99            for line in lines:
100                try:
101                    (header, data) = line.split(": ", 1)
102                    headers[header] = data
103                except ValueError:
104                    pass
105
106            if command == "AUTHTOKEN":
107                user = args[0]
108                token = headers["Auth-token"]
109                if token == "1": #FIXME
110                    self.auth = user
111                    self.makeReply(200, "Authentication successful")
112                else:
113                    self.makeReply(401)
114            elif command == "CONNECTVNC":
115                vmname = args[0]
116                if ("Auth-token" in headers):
117                    token = headers["Auth-token"]
118                    self.validateToken(token)
119                    if self.auth is not None:
120                        port = getPort(vmname, self.auth_data)
121                        if port is not None: # FIXME
122                            if port != 0:
123                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
124                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
125                            else:
126                                self.makeReply(404, "Unable to find VNC for VM "+vmname)
127                        else:
128                            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
129                    else:
130                        if self.auth_error:
131                            self.makeReply(401, self.auth_error)
132                        else:
133                            self.makeReply(401, "Invalid token")
134                else:
135                    self.makeReply(401, "Login first")
136            else:
137                self.makeReply(501, "unknown method "+command)
138            self.buf=''
139        if False and '\000' in self.buf[8:]:
140            head,self.buf=self.buf[:8],self.buf[8:]
141            try:
142                version,code,port=struct.unpack("!BBH",head[:4])
143            except struct.error:
144                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
145            user,self.buf=string.split(self.buf,"\000",1)
146            if head[4:7]=="\000\000\000": # domain is after
147                server,self.buf=string.split(self.buf,'\000',1)
148                #server=gethostbyname(server)
149            else:
150                server=socket.inet_ntoa(head[4:8])
151            assert version==4, "Bad version code: %s"%version
152            if not self.authorize(code,server,port,user):
153                self.makeReply(91)
154                return
155            if code==1: # CONNECT
156                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
157                d.addErrback(lambda result, self=self: self.makeReply(91))
158            else:
159                raise RuntimeError, "Bad Connect Code: %s" % code
160            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
161
162    def connectionLost(self, reason):
163        if self.otherConn:
164            self.otherConn.transport.loseConnection()
165
166    def authorize(self,code,server,port,user):
167        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
168        return 1
169
170    def connectClass(self, host, port, klass, *args):
171        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
172
173    def makeReply(self,reply,message=""):
174        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
175        if int(reply / 100)!=2: self.transport.loseConnection()
176
177    def write(self,data):
178        self.transport.write(data)
179
180    def log(self,proto,data):
181        peer = self.transport.getPeer()
182        their_peer = self.otherConn.transport.getPeer()
183        print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
184                                        peer.host,peer.port,
185                                        ((proto==self and '<') or '>'),
186                                        their_peer.host,their_peer.port),
187        while data:
188            p,data=data[:16],data[16:]
189            print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
190            print ((16-len(p))*3*' '),
191            for c in p:
192                if len(repr(c))>3: print '.',
193                else: print c,
194            print ""
195        print ""
196
197
198class VNCAuthFactory(protocol.Factory):
199    """A factory for a VNC auth proxy.
200   
201    Constructor accepts one argument, a log file name.
202    """
203   
204    def __init__(self, server):
205        self.server = server
206   
207    def buildProtocol(self, addr):
208        return VNCAuth(self.server)
209
Note: See TracBrowser for help on using the repository browser.