source: trunk/packages/invirt-vnc-server/python/vnc/extauth.py @ 1386

Last change on this file since 1386 was 1386, checked in by broder, 16 years ago

sipb-xen-vnc-server -> invirt-vnc-server

File size: 7.0 KB
Line 
1"""
2Wrapper for Invirt VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
10import sys
11import struct
12import string
13import cPickle
14# Python 2.5:
15#import hashlib
16import sha
17import hmac
18import base64
19import socket
20import time
21import get_port
22
23TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
24
25def getPort(name, auth_data):
26    if (auth_data["machine"] == name):
27        port = get_port.findPort(name)
28        if port is None:
29            return 0
30        return int(port.split(':')[1])
31    else:
32        return None
33   
34class VNCAuthOutgoing(protocol.Protocol):
35   
36    def __init__(self,socks):
37        self.socks=socks
38
39    def connectionMade(self):
40        peer = self.transport.getPeer()
41        self.socks.makeReply(200)
42        self.socks.otherConn=self
43
44    def connectionLost(self, reason):
45        self.socks.transport.loseConnection()
46
47    def dataReceived(self,data):
48        self.socks.write(data)
49
50    def write(self,data):
51        self.transport.write(data)
52
53
54class VNCAuth(protocol.Protocol):
55   
56    def __init__(self,server="localhost"):
57        self.server=server
58        self.auth=None
59   
60    def connectionMade(self):
61        self.buf=""
62        self.otherConn=None
63
64    def validateToken(self, token):
65        global TOKEN_KEY
66        self.auth_error = "Invalid token"
67        try:
68            token = base64.urlsafe_b64decode(token)
69            token = cPickle.loads(token)
70            m = hmac.new(TOKEN_KEY, digestmod=sha)
71            m.update(token['data'])
72            if (m.digest() == token['digest']):
73                data = cPickle.loads(token['data'])
74                expires = data["expires"]
75                if (time.time() < expires):
76                    self.auth = data["user"]
77                    self.auth_error = None
78                    self.auth_machine = data["machine"]
79                    self.auth_data = data
80                else:
81                    self.auth_error = "Token has expired; please try logging in again"
82        except (TypeError, cPickle.UnpicklingError):
83            self.auth = None           
84            print sys.exc_info()
85
86    def dataReceived(self,data):
87        if self.otherConn:
88            self.otherConn.write(data)
89            return
90        self.buf=self.buf+data
91        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
92            lines = self.buf.splitlines()
93            args = lines.pop(0).split()
94            command = args.pop(0)
95            headers = {}
96            for line in lines:
97                try:
98                    (header, data) = line.split(": ", 1)
99                    headers[header] = data
100                except ValueError:
101                    pass
102
103            if command == "AUTHTOKEN":
104                user = args[0]
105                token = headers["Auth-token"]
106                if token == "1": #FIXME
107                    self.auth = user
108                    self.makeReply(200, "Authentication successful")
109                else:
110                    self.makeReply(401)
111            elif command == "CONNECTVNC":
112                vmname = args[0]
113                if ("Auth-token" in headers):
114                    token = headers["Auth-token"]
115                    self.validateToken(token)
116                    if self.auth is not None:
117                        port = getPort(vmname, self.auth_data)
118                        if port is not None: # FIXME
119                            if port != 0:
120                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
121                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
122                            else:
123                                self.makeReply(404, "Unable to find VNC for VM "+vmname)
124                        else:
125                            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
126                    else:
127                        if self.auth_error:
128                            self.makeReply(401, self.auth_error)
129                        else:
130                            self.makeReply(401, "Invalid token")
131                else:
132                    self.makeReply(401, "Login first")
133            else:
134                self.makeReply(501, "unknown method "+command)
135            self.buf=''
136        if False and '\000' in self.buf[8:]:
137            head,self.buf=self.buf[:8],self.buf[8:]
138            try:
139                version,code,port=struct.unpack("!BBH",head[:4])
140            except struct.error:
141                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
142            user,self.buf=string.split(self.buf,"\000",1)
143            if head[4:7]=="\000\000\000": # domain is after
144                server,self.buf=string.split(self.buf,'\000',1)
145                #server=gethostbyname(server)
146            else:
147                server=socket.inet_ntoa(head[4:8])
148            assert version==4, "Bad version code: %s"%version
149            if not self.authorize(code,server,port,user):
150                self.makeReply(91)
151                return
152            if code==1: # CONNECT
153                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
154                d.addErrback(lambda result, self=self: self.makeReply(91))
155            else:
156                raise RuntimeError, "Bad Connect Code: %s" % code
157            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
158
159    def connectionLost(self, reason):
160        if self.otherConn:
161            self.otherConn.transport.loseConnection()
162
163    def authorize(self,code,server,port,user):
164        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
165        return 1
166
167    def connectClass(self, host, port, klass, *args):
168        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
169
170    def makeReply(self,reply,message=""):
171        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
172        if int(reply / 100)!=2: self.transport.loseConnection()
173
174    def write(self,data):
175        self.transport.write(data)
176
177    def log(self,proto,data):
178        peer = self.transport.getPeer()
179        their_peer = self.otherConn.transport.getPeer()
180        print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
181                                        peer.host,peer.port,
182                                        ((proto==self and '<') or '>'),
183                                        their_peer.host,their_peer.port),
184        while data:
185            p,data=data[:16],data[16:]
186            print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
187            print ((16-len(p))*3*' '),
188            for c in p:
189                if len(repr(c))>3: print '.',
190                else: print c,
191            print ""
192        print ""
193
194
195class VNCAuthFactory(protocol.Factory):
196    """A factory for a VNC auth proxy.
197   
198    Constructor accepts one argument, a log file name.
199    """
200   
201    def __init__(self, server):
202        self.server = server
203   
204    def buildProtocol(self, addr):
205        return VNCAuth(self.server)
206
Note: See TracBrowser for help on using the repository browser.