source: trunk/packages/invirt-vnc-server/python/vnc/extauth.py @ 1403

Last change on this file since 1403 was 1403, checked in by broder, 16 years ago

File reads are cheap. Don't cache the VNC token key in the server code

File size: 7.0 KB
Line 
1"""
2Wrapper for Invirt VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
10import sys
11import struct
12import string
13import cPickle
14# Python 2.5:
15#import hashlib
16import sha
17import hmac
18import base64
19import socket
20import time
21
22def getTokenKey():
23    return file('/etc/invirt/secrets/vnc-key').read().strip()
24
25def getPort(name, auth_data):
26    import get_port
27    if (auth_data["machine"] == name):
28        port = get_port.findPort(name)
29        if port is None:
30            return 0
31        return int(port.split(':')[1])
32    else:
33        return None
34   
35class VNCAuthOutgoing(protocol.Protocol):
36   
37    def __init__(self,socks):
38        self.socks=socks
39
40    def connectionMade(self):
41        peer = self.transport.getPeer()
42        self.socks.makeReply(200)
43        self.socks.otherConn=self
44
45    def connectionLost(self, reason):
46        self.socks.transport.loseConnection()
47
48    def dataReceived(self,data):
49        self.socks.write(data)
50
51    def write(self,data):
52        self.transport.write(data)
53
54
55class VNCAuth(protocol.Protocol):
56   
57    def __init__(self,server="localhost"):
58        self.server=server
59        self.auth=None
60   
61    def connectionMade(self):
62        self.buf=""
63        self.otherConn=None
64
65    def validateToken(self, token):
66        self.auth_error = "Invalid token"
67        try:
68            token = base64.urlsafe_b64decode(token)
69            token = cPickle.loads(token)
70            m = hmac.new(getTokenKey(), digestmod=sha)
71            m.update(token['data'])
72            if (m.digest() == token['digest']):
73                data = cPickle.loads(token['data'])
74                expires = data["expires"]
75                if (time.time() < expires):
76                    self.auth = data["user"]
77                    self.auth_error = None
78                    self.auth_machine = data["machine"]
79                    self.auth_data = data
80                else:
81                    self.auth_error = "Token has expired; please try logging in again"
82        except (TypeError, cPickle.UnpicklingError):
83            self.auth = None           
84            print sys.exc_info()
85
86    def dataReceived(self,data):
87        if self.otherConn:
88            self.otherConn.write(data)
89            return
90        self.buf=self.buf+data
91        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
92            lines = self.buf.splitlines()
93            args = lines.pop(0).split()
94            command = args.pop(0)
95            headers = {}
96            for line in lines:
97                try:
98                    (header, data) = line.split(": ", 1)
99                    headers[header] = data
100                except ValueError:
101                    pass
102
103            if command == "AUTHTOKEN":
104                user = args[0]
105                token = headers["Auth-token"]
106                if token == "1": #FIXME
107                    self.auth = user
108                    self.makeReply(200, "Authentication successful")
109                else:
110                    self.makeReply(401)
111            elif command == "CONNECTVNC":
112                vmname = args[0]
113                if ("Auth-token" in headers):
114                    token = headers["Auth-token"]
115                    self.validateToken(token)
116                    if self.auth is not None:
117                        port = getPort(vmname, self.auth_data)
118                        if port is not None: # FIXME
119                            if port != 0:
120                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
121                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
122                            else:
123                                self.makeReply(404, "Unable to find VNC for VM "+vmname)
124                        else:
125                            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
126                    else:
127                        if self.auth_error:
128                            self.makeReply(401, self.auth_error)
129                        else:
130                            self.makeReply(401, "Invalid token")
131                else:
132                    self.makeReply(401, "Login first")
133            else:
134                self.makeReply(501, "unknown method "+command)
135            self.buf=''
136        if False and '\000' in self.buf[8:]:
137            head,self.buf=self.buf[:8],self.buf[8:]
138            try:
139                version,code,port=struct.unpack("!BBH",head[:4])
140            except struct.error:
141                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
142            user,self.buf=string.split(self.buf,"\000",1)
143            if head[4:7]=="\000\000\000": # domain is after
144                server,self.buf=string.split(self.buf,'\000',1)
145                #server=gethostbyname(server)
146            else:
147                server=socket.inet_ntoa(head[4:8])
148            assert version==4, "Bad version code: %s"%version
149            if not self.authorize(code,server,port,user):
150                self.makeReply(91)
151                return
152            if code==1: # CONNECT
153                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
154                d.addErrback(lambda result, self=self: self.makeReply(91))
155            else:
156                raise RuntimeError, "Bad Connect Code: %s" % code
157            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
158
159    def connectionLost(self, reason):
160        if self.otherConn:
161            self.otherConn.transport.loseConnection()
162
163    def authorize(self,code,server,port,user):
164        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
165        return 1
166
167    def connectClass(self, host, port, klass, *args):
168        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
169
170    def makeReply(self,reply,message=""):
171        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
172        if int(reply / 100)!=2: self.transport.loseConnection()
173
174    def write(self,data):
175        self.transport.write(data)
176
177    def log(self,proto,data):
178        peer = self.transport.getPeer()
179        their_peer = self.otherConn.transport.getPeer()
180        print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
181                                        peer.host,peer.port,
182                                        ((proto==self and '<') or '>'),
183                                        their_peer.host,their_peer.port),
184        while data:
185            p,data=data[:16],data[16:]
186            print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
187            print ((16-len(p))*3*' '),
188            for c in p:
189                if len(repr(c))>3: print '.',
190                else: print c,
191            print ""
192        print ""
193
194
195class VNCAuthFactory(protocol.Factory):
196    """A factory for a VNC auth proxy.
197   
198    Constructor accepts one argument, a log file name.
199    """
200   
201    def __init__(self, server):
202        self.server = server
203   
204    def buildProtocol(self, addr):
205        return VNCAuth(self.server)
206
Note: See TracBrowser for help on using the repository browser.