source: trunk/vnc/vnc_server/vncexternalauth.py @ 150

Last change on this file since 150 was 125, checked in by quentin, 17 years ago

Correctly verify authentication tokens, and disable backdoor

File size: 7.3 KB
RevLine 
[115]1"""
2Wrapper for sipb-xen VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
[125]10import sys
[115]11import struct
12import string
13import cPickle
14# Python 2.5:
15#import hashlib
16import sha
17import hmac
18import base64
19import socket
20import time
21import get_port
22
23TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
24
[125]25def getPort(name, auth_data):
26    if (auth_data["machine"] == name):
27        port = get_port.findPort(name)
28        if port is None:
29            return 0
30        return int(port.split(':')[1])
31    else:
32        return None
33   
[115]34class VNCAuthOutgoing(protocol.Protocol):
35   
36    def __init__(self,socks):
37        self.socks=socks
38
39    def connectionMade(self):
40        peer = self.transport.getPeer()
41        self.socks.makeReply(200)
42        self.socks.otherConn=self
43
44    def connectionLost(self, reason):
45        self.socks.transport.loseConnection()
46
47    def dataReceived(self,data):
48        self.socks.write(data)
49
50    def write(self,data):
51        #self.socks.log(self,data)
52        self.transport.write(data)
53
54
55class VNCAuth(protocol.Protocol):
56   
57    def __init__(self,logging=None,server="localhost"):
58        self.logging=logging
59        self.server=server
60        self.auth=None
61   
62    def connectionMade(self):
63        self.buf=""
64        self.otherConn=None
65
66    def validateToken(self, token):
67        global TOKEN_KEY
[125]68        try:
69            token = base64.urlsafe_b64decode(token)
70            token = cPickle.loads(token)
71            m = hmac.new(TOKEN_KEY, digestmod=sha)
72            m.update(token['data'])
73            self.auth_error = "Invalid token"
74            if (m.digest() == token['digest']):
75                data = cPickle.loads(token['data'])
76                expires = data["expires"]
77                if (time.time() < expires):
78                    self.auth = data["user"]
79                    self.auth_error = None
80                    self.auth_machine = data["machine"]
81                    self.auth_data = data
82                else:
83                    self.auth_error = "Token has expired; please try logging in again"
84        except:
85            self.auth = None
86            print sys.exc_info()
[115]87
88    def dataReceived(self,data):
89        if self.otherConn:
90            self.otherConn.write(data)
91            return
92        self.buf=self.buf+data
93        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
94            lines = self.buf.splitlines()
95            args = lines.pop(0).split()
96            command = args.pop(0)
97            headers = {}
98            for line in lines:
99                try:
100                    (header, data) = line.split(": ", 1)
101                    headers[header] = data
102                except:
103                    pass
104
105            if command == "AUTHTOKEN":
106                user = args[0]
107                token = headers["Auth-token"]
108                if token == "1": #FIXME
109                    self.auth = user
110                    self.makeReply(200, "Authentication successful")
111                else:
112                    self.makeReply(401)
113            elif command == "CONNECTVNC":
114                vmname = args[0]
115                if ("Auth-token" in headers):
116                    token = headers["Auth-token"]
117                    try:
118                        self.validateToken(token)
119                    finally:
120                        if self.auth is not None:
[125]121                            port = getPort(vmname, self.auth_data)
[115]122                            if port is not None: # FIXME
[125]123                                if port is not 0:
124                                    d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
125                                    d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
126                                else:
127                                    self.makeReply(404, "Unable to find VNC for VM "+vmname)
[115]128                            else:
129                                self.makeReply(401, "Unauthorized to connect to VM "+vmname)
130                        else:
[125]131                            if self.auth_error:
132                                self.makeReply(401, self.auth_error)
133                            else:
134                                self.makeReply(401, "Invalid token")
[115]135                else:
136                    self.makeReply(401, "Login first")
137            else:
138                self.makeReply(501, "unknown method "+command)
139            self.buf=''
140        if False and '\000' in self.buf[8:]:
141            head,self.buf=self.buf[:8],self.buf[8:]
142            try:
143                version,code,port=struct.unpack("!BBH",head[:4])
144            except struct.error:
145                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
146            user,self.buf=string.split(self.buf,"\000",1)
147            if head[4:7]=="\000\000\000": # domain is after
148                server,self.buf=string.split(self.buf,'\000',1)
149                #server=gethostbyname(server)
150            else:
151                server=socket.inet_ntoa(head[4:8])
152            assert version==4, "Bad version code: %s"%version
153            if not self.authorize(code,server,port,user):
154                self.makeReply(91)
155                return
156            if code==1: # CONNECT
157                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
158                d.addErrback(lambda result, self=self: self.makeReply(91))
159            else:
160                raise RuntimeError, "Bad Connect Code: %s" % code
161            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
162
163    def connectionLost(self, reason):
164        if self.otherConn:
165            self.otherConn.transport.loseConnection()
166
167    def authorize(self,code,server,port,user):
168        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
169        return 1
170
171    def connectClass(self, host, port, klass, *args):
172        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
173
174    def makeReply(self,reply,message=""):
175        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
176        if int(reply / 100)!=2: self.transport.loseConnection()
177
178    def write(self,data):
179        #self.log(self,data)
180        self.transport.write(data)
181
182    def log(self,proto,data):
183        if not self.logging: return
184        peer = self.transport.getPeer()
185        their_peer = self.otherConn.transport.getPeer()
186        f=open(self.logging,"a")
187        f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
188                                        peer.host,peer.port,
189                                        ((proto==self and '<') or '>'),
190                                        their_peer.host,their_peer.port))
191        while data:
192            p,data=data[:16],data[16:]
193            f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
194            f.write((16-len(p))*3*' ')
195            for c in p:
196                if len(repr(c))>3: f.write('.')
197                else: f.write(c)
198            f.write('\n')
199        f.write('\n')
200        f.close()
201
202
203class VNCAuthFactory(protocol.Factory):
204    """A factory for a VNC auth proxy.
205   
206    Constructor accepts one argument, a log file name.
207    """
208   
209    def __init__(self, log, server):
210        self.logging = log
211        self.server = server
212   
213    def buildProtocol(self, addr):
214        return VNCAuth(self.logging, self.server)
215
Note: See TracBrowser for help on using the repository browser.