source: trunk/vnc/vnc_server/vncexternalauth.py @ 116

Last change on this file since 116 was 115, checked in by ecprice, 17 years ago

VNC server commit.

File size: 6.6 KB
RevLine 
[115]1"""
2Wrapper for sipb-xen VNC proxying
3"""
4
5# twisted imports
6from twisted.internet import reactor, protocol, defer
7from twisted.python import log
8
9# python imports
10import struct
11import string
12import cPickle
13# Python 2.5:
14#import hashlib
15import sha
16import hmac
17import base64
18import socket
19import time
20import get_port
21
22TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
23
24def getPort(name, auth):
25    port = get_port.findPort(name)
26    if port is None:
27        return 0
28    return int(port.split(':')[1])
29
30class VNCAuthOutgoing(protocol.Protocol):
31   
32    def __init__(self,socks):
33        self.socks=socks
34
35    def connectionMade(self):
36        peer = self.transport.getPeer()
37        self.socks.makeReply(200)
38        self.socks.otherConn=self
39
40    def connectionLost(self, reason):
41        self.socks.transport.loseConnection()
42
43    def dataReceived(self,data):
44        self.socks.write(data)
45
46    def write(self,data):
47        #self.socks.log(self,data)
48        self.transport.write(data)
49
50
51class VNCAuth(protocol.Protocol):
52   
53    def __init__(self,logging=None,server="localhost"):
54        self.logging=logging
55        self.server=server
56        self.auth=None
57   
58    def connectionMade(self):
59        self.buf=""
60        self.otherConn=None
61
62    def validateToken(self, token):
63        global TOKEN_KEY
64        if token == "quentin":
65            self.auth = "quentin@ATHENA.MIT.EDU"
66            return #FIXME
67        token = base64.urlsafe_b64decode(token)
68        token = cPickle.load(token)
69        m = hmac.new(TOKEN_KEY, digestmod=sha)
70        m.update(token['data'])
71        if (m.digest() == token['digest']):
72            data = cPickle.load(token['data'])
73            expires = data["expires"]
74            if (time.time() < expires):
75                self.auth = data["user"]
76
77    def dataReceived(self,data):
78        if self.otherConn:
79            self.otherConn.write(data)
80            return
81        self.buf=self.buf+data
82        if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
83            lines = self.buf.splitlines()
84            args = lines.pop(0).split()
85            command = args.pop(0)
86            headers = {}
87            for line in lines:
88                try:
89                    (header, data) = line.split(": ", 1)
90                    headers[header] = data
91                except:
92                    pass
93
94            if command == "AUTHTOKEN":
95                user = args[0]
96                token = headers["Auth-token"]
97                if token == "1": #FIXME
98                    self.auth = user
99                    self.makeReply(200, "Authentication successful")
100                else:
101                    self.makeReply(401)
102            elif command == "CONNECTVNC":
103                vmname = args[0]
104                if ("Auth-token" in headers):
105                    token = headers["Auth-token"]
106                    try:
107                        self.validateToken(token)
108                    finally:
109                        if self.auth is not None:
110                            port = getPort(vmname, self.auth)
111                            if port is not None: # FIXME
112                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
113                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
114                            else:
115                                self.makeReply(401, "Unauthorized to connect to VM "+vmname)
116                        else:
117                            self.makeReply(401, "Invalid token")
118                else:
119                    self.makeReply(401, "Login first")
120            else:
121                self.makeReply(501, "unknown method "+command)
122            self.buf=''
123        if False and '\000' in self.buf[8:]:
124            head,self.buf=self.buf[:8],self.buf[8:]
125            try:
126                version,code,port=struct.unpack("!BBH",head[:4])
127            except struct.error:
128                raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
129            user,self.buf=string.split(self.buf,"\000",1)
130            if head[4:7]=="\000\000\000": # domain is after
131                server,self.buf=string.split(self.buf,'\000',1)
132                #server=gethostbyname(server)
133            else:
134                server=socket.inet_ntoa(head[4:8])
135            assert version==4, "Bad version code: %s"%version
136            if not self.authorize(code,server,port,user):
137                self.makeReply(91)
138                return
139            if code==1: # CONNECT
140                d = self.connectClass(server, port, SOCKSv4Outgoing, self)
141                d.addErrback(lambda result, self=self: self.makeReply(91))
142            else:
143                raise RuntimeError, "Bad Connect Code: %s" % code
144            assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
145
146    def connectionLost(self, reason):
147        if self.otherConn:
148            self.otherConn.transport.loseConnection()
149
150    def authorize(self,code,server,port,user):
151        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
152        return 1
153
154    def connectClass(self, host, port, klass, *args):
155        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
156
157    def makeReply(self,reply,message=""):
158        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
159        if int(reply / 100)!=2: self.transport.loseConnection()
160
161    def write(self,data):
162        #self.log(self,data)
163        self.transport.write(data)
164
165    def log(self,proto,data):
166        if not self.logging: return
167        peer = self.transport.getPeer()
168        their_peer = self.otherConn.transport.getPeer()
169        f=open(self.logging,"a")
170        f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
171                                        peer.host,peer.port,
172                                        ((proto==self and '<') or '>'),
173                                        their_peer.host,their_peer.port))
174        while data:
175            p,data=data[:16],data[16:]
176            f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
177            f.write((16-len(p))*3*' ')
178            for c in p:
179                if len(repr(c))>3: f.write('.')
180                else: f.write(c)
181            f.write('\n')
182        f.write('\n')
183        f.close()
184
185
186class VNCAuthFactory(protocol.Factory):
187    """A factory for a VNC auth proxy.
188   
189    Constructor accepts one argument, a log file name.
190    """
191   
192    def __init__(self, log, server):
193        self.logging = log
194        self.server = server
195   
196    def buildProtocol(self, addr):
197        return VNCAuth(self.logging, self.server)
198
Note: See TracBrowser for help on using the repository browser.