""" Wrapper for Invirt VNC proxying """ # twisted imports from twisted.internet import reactor, protocol, defer from twisted.python import log # python imports import sys import struct import string import cPickle # Python 2.5: #import hashlib import sha import hmac import base64 import socket import time import get_port TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN" def getPort(name, auth_data): if (auth_data["machine"] == name): port = get_port.findPort(name) if port is None: return 0 return int(port.split(':')[1]) else: return None class VNCAuthOutgoing(protocol.Protocol): def __init__(self,socks): self.socks=socks def connectionMade(self): peer = self.transport.getPeer() self.socks.makeReply(200) self.socks.otherConn=self def connectionLost(self, reason): self.socks.transport.loseConnection() def dataReceived(self,data): self.socks.write(data) def write(self,data): self.transport.write(data) class VNCAuth(protocol.Protocol): def __init__(self,server="localhost"): self.server=server self.auth=None def connectionMade(self): self.buf="" self.otherConn=None def validateToken(self, token): global TOKEN_KEY self.auth_error = "Invalid token" try: token = base64.urlsafe_b64decode(token) token = cPickle.loads(token) m = hmac.new(TOKEN_KEY, digestmod=sha) m.update(token['data']) if (m.digest() == token['digest']): data = cPickle.loads(token['data']) expires = data["expires"] if (time.time() < expires): self.auth = data["user"] self.auth_error = None self.auth_machine = data["machine"] self.auth_data = data else: self.auth_error = "Token has expired; please try logging in again" except (TypeError, cPickle.UnpicklingError): self.auth = None print sys.exc_info() def dataReceived(self,data): if self.otherConn: self.otherConn.write(data) return self.buf=self.buf+data if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf): lines = self.buf.splitlines() args = lines.pop(0).split() command = args.pop(0) headers = {} for line in lines: try: (header, data) = line.split(": ", 1) headers[header] = data except ValueError: pass if command == "AUTHTOKEN": user = args[0] token = headers["Auth-token"] if token == "1": #FIXME self.auth = user self.makeReply(200, "Authentication successful") else: self.makeReply(401) elif command == "CONNECTVNC": vmname = args[0] if ("Auth-token" in headers): token = headers["Auth-token"] self.validateToken(token) if self.auth is not None: port = getPort(vmname, self.auth_data) if port is not None: # FIXME if port != 0: d = self.connectClass(self.server, port, VNCAuthOutgoing, self) d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage())) else: self.makeReply(404, "Unable to find VNC for VM "+vmname) else: self.makeReply(401, "Unauthorized to connect to VM "+vmname) else: if self.auth_error: self.makeReply(401, self.auth_error) else: self.makeReply(401, "Invalid token") else: self.makeReply(401, "Login first") else: self.makeReply(501, "unknown method "+command) self.buf='' if False and '\000' in self.buf[8:]: head,self.buf=self.buf[:8],self.buf[8:] try: version,code,port=struct.unpack("!BBH",head[:4]) except struct.error: raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf)) user,self.buf=string.split(self.buf,"\000",1) if head[4:7]=="\000\000\000": # domain is after server,self.buf=string.split(self.buf,'\000',1) #server=gethostbyname(server) else: server=socket.inet_ntoa(head[4:8]) assert version==4, "Bad version code: %s"%version if not self.authorize(code,server,port,user): self.makeReply(91) return if code==1: # CONNECT d = self.connectClass(server, port, SOCKSv4Outgoing, self) d.addErrback(lambda result, self=self: self.makeReply(91)) else: raise RuntimeError, "Bad Connect Code: %s" % code assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf) def connectionLost(self, reason): if self.otherConn: self.otherConn.transport.loseConnection() def authorize(self,code,server,port,user): log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user)) return 1 def connectClass(self, host, port, klass, *args): return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port) def makeReply(self,reply,message=""): self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message)) if int(reply / 100)!=2: self.transport.loseConnection() def write(self,data): self.transport.write(data) def log(self,proto,data): peer = self.transport.getPeer() their_peer = self.otherConn.transport.getPeer() print "%s\t%s:%d %s %s:%d\n"%(time.ctime(), peer.host,peer.port, ((proto==self and '<') or '>'), their_peer.host,their_peer.port), while data: p,data=data[:16],data[16:] print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ', print ((16-len(p))*3*' '), for c in p: if len(repr(c))>3: print '.', else: print c, print "" print "" class VNCAuthFactory(protocol.Factory): """A factory for a VNC auth proxy. Constructor accepts one argument, a log file name. """ def __init__(self, server): self.server = server def buildProtocol(self, addr): return VNCAuth(self.server)