source: trunk/packages/sipb-xen-vnc-server/code/vncexternalauth.py @ 288

Last change on this file since 288 was 288, checked in by broder, 17 years ago

Fixed the init scripts and control file. Also, imported quentin's UNCOMMITTED CODE

File size: 7.3 KB
Line 
1"""
2Wrapper for sipb-xen VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
10import sys
11import struct
12import string
13import cPickle
14# Python 2.5:
15#import hashlib
16import sha
17import hmac
18import base64
19import socket
20import time
21import get_port
22
23TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
24
25def getPort(name, auth_data):
26    if (auth_data["machine"] == name):
27        port = get_port.findPort(name)
28        if port is None:
29            return 0
30        return int(port.split(':')[1])
31    else:
32        return None
33   
34class VNCAuthOutgoing(protocol.Protocol):
35   
36    def __init__(self,socks):
37        self.socks=socks
38
39    def connectionMade(self):
40        peer = self.transport.getPeer()
41        self.socks.makeReply(200)
42        self.socks.otherConn=self
43
44    def connectionLost(self, reason):
45        self.socks.transport.loseConnection()
46
47    def dataReceived(self,data):
48        #self.socks.log(self,"R"+data)
49        self.socks.write(data)
50
51    def write(self,data):
52        #self.socks.log(self,'W'+data)
53        self.transport.write(data)
54
55
56class VNCAuth(protocol.Protocol):
57   
58    def __init__(self,logging=None,server="localhost"):
59        self.logging=logging
60        self.server=server
61        self.auth=None
62   
63    def connectionMade(self):
64        self.buf=""
65        self.otherConn=None
66
67    def validateToken(self, token):
68        global TOKEN_KEY
69        self.auth_error = "Invalid token"
70        try:
71            token = base64.urlsafe_b64decode(token)
72            token = cPickle.loads(token)
73            m = hmac.new(TOKEN_KEY, digestmod=sha)
74            m.update(token['data'])
75            if (m.digest() == token['digest']):
76                data = cPickle.loads(token['data'])
77                expires = data["expires"]
78                if (time.time() < expires):
79                    self.auth = data["user"]
80                    self.auth_error = None
81                    self.auth_machine = data["machine"]
82                    self.auth_data = data
83                else:
84                    self.auth_error = "Token has expired; please try logging in again"
85        except (TypeError, cPickle.UnpicklingError):
86            self.auth = None           
87            print sys.exc_info()
88
89    def dataReceived(self,data):
90        if self.otherConn:
91            self.otherConn.write(data)
92            return
93        self.buf=self.buf+data
94        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
95            lines = self.buf.splitlines()
96            args = lines.pop(0).split()
97            command = args.pop(0)
98            headers = {}
99            for line in lines:
100                try:
101                    (header, data) = line.split(": ", 1)
102                    headers[header] = data
103                except ValueError:
104                    pass
105
106            if command == "AUTHTOKEN":
107                user = args[0]
108                token = headers["Auth-token"]
109                if token == "1": #FIXME
110                    self.auth = user
111                    self.makeReply(200, "Authentication successful")
112                else:
113                    self.makeReply(401)
114            elif command == "CONNECTVNC":
115                vmname = args[0]
116                if ("Auth-token" in headers):
117                    token = headers["Auth-token"]
118                    self.validateToken(token)
119                    if self.auth is not None:
120                        port = getPort(vmname, self.auth_data)
121                        if port is not None: # FIXME
122                            if port != 0:
123                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
124                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
125                            else:
126                                self.makeReply(404, "Unable to find VNC for VM "+vmname)
127                        else:
128                            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
129                    else:
130                        if self.auth_error:
131                            self.makeReply(401, self.auth_error)
132                        else:
133                            self.makeReply(401, "Invalid token")
134                else:
135                    self.makeReply(401, "Login first")
136            else:
137                self.makeReply(501, "unknown method "+command)
138            self.buf=''
139        if False and '\000' in self.buf[8:]:
140            head,self.buf=self.buf[:8],self.buf[8:]
141            try:
142                version,code,port=struct.unpack("!BBH",head[:4])
143            except struct.error:
144                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
145            user,self.buf=string.split(self.buf,"\000",1)
146            if head[4:7]=="\000\000\000": # domain is after
147                server,self.buf=string.split(self.buf,'\000',1)
148                #server=gethostbyname(server)
149            else:
150                server=socket.inet_ntoa(head[4:8])
151            assert version==4, "Bad version code: %s"%version
152            if not self.authorize(code,server,port,user):
153                self.makeReply(91)
154                return
155            if code==1: # CONNECT
156                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
157                d.addErrback(lambda result, self=self: self.makeReply(91))
158            else:
159                raise RuntimeError, "Bad Connect Code: %s" % code
160            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
161
162    def connectionLost(self, reason):
163        if self.otherConn:
164            self.otherConn.transport.loseConnection()
165
166    def authorize(self,code,server,port,user):
167        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
168        return 1
169
170    def connectClass(self, host, port, klass, *args):
171        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
172
173    def makeReply(self,reply,message=""):
174        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
175        if int(reply / 100)!=2: self.transport.loseConnection()
176
177    def write(self,data):
178        #self.log(self,data)
179        self.transport.write(data)
180
181    def log(self,proto,data):
182        if not self.logging: return
183        peer = self.transport.getPeer()
184        their_peer = self.otherConn.transport.getPeer()
185        f=open(self.logging,"a")
186        f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
187                                        peer.host,peer.port,
188                                        ((proto==self and '<') or '>'),
189                                        their_peer.host,their_peer.port))
190        while data:
191            p,data=data[:16],data[16:]
192            f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
193            f.write((16-len(p))*3*' ')
194            for c in p:
195                if len(repr(c))>3: f.write('.')
196                else: f.write(c)
197            f.write('\n')
198        f.write('\n')
199        f.close()
200
201
202class VNCAuthFactory(protocol.Factory):
203    """A factory for a VNC auth proxy.
204   
205    Constructor accepts one argument, a log file name.
206    """
207   
208    def __init__(self, log, server):
209        self.logging = log
210        self.server = server
211   
212    def buildProtocol(self, addr):
213        return VNCAuth(self.logging, self.server)
214
Note: See TracBrowser for help on using the repository browser.