"""
Wrapper for Invirt VNC proxying
"""

# twisted imports
from twisted.internet import reactor, protocol, defer
from twisted.python import log

# python imports
import sys
import struct
import string
import cPickle
# Python 2.5:
#import hashlib
import sha
import hmac
import base64
import socket
import time

def getTokenKey():
    """Return the secret key stored in /etc/invirt/vnc/token-key.

    The file contents are returned with leading and trailing whitespace
    stripped.  VNCAuth.validateToken uses the value as the HMAC key when
    checking a token's digest.

    Raises IOError if the key file cannot be opened or read.
    """
    key_file = open('/etc/invirt/vnc/token-key')
    # try/finally closes the descriptor deterministically, even when
    # read() fails, instead of leaving it for the garbage collector.
    try:
        return key_file.read().strip()
    finally:
        key_file.close()

def getPort(name, auth_data):
    """Look up the VNC port for VM ``name`` on behalf of a token holder.

    ``auth_data`` is the token payload; its "machine" entry must equal
    ``name`` for the lookup to be allowed.

    Returns:
        None  -- the token was issued for a different machine
        0     -- the token matches but get_port.findPort found nothing
        int   -- the port number otherwise
    """
    import get_port

    # Guard clause: a token only ever grants access to its own machine.
    if auth_data["machine"] != name:
        return None

    location = get_port.findPort(name)
    if location is None:
        return 0

    # The port is taken from the second colon-separated field of the
    # findPort result (presumably "host:port" -- confirm in get_port).
    fields = location.split(':')
    return int(fields[1])
    
class VNCAuthOutgoing(protocol.Protocol):
    """Proxy-to-server half of a relayed VNC session.

    ``socks`` is the client-facing protocol object this connection is
    paired with.  Bytes received here are handed to ``socks.write``;
    bytes passed to ``write`` are sent out on this connection.
    """

    def __init__(self, socks):
        self.socks = socks

    def connectionMade(self):
        """Report success to the client and switch it into relay mode."""
        self.socks.makeReply(200)
        # Once otherConn is set, the client side forwards every further
        # byte here instead of parsing it as a request.
        self.socks.otherConn = self

    def connectionLost(self, reason):
        """Tear down the client connection when this one goes away."""
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        """Relay bytes arriving on this connection to the client."""
        self.socks.write(data)

    def write(self, data):
        """Send ``data`` out on this connection."""
        self.transport.write(data)


class VNCAuth(protocol.Protocol):
    """Client-facing half of the VNC auth proxy.

    The client opens with a small HTTP-like request: a command line,
    optional "Name: value" header lines, then a blank line.  For example:

        CONNECTVNC <vmname>
        Auth-token: <token>

    Every reply has the form "VNCProxy/1.0 <code> <message>" followed by
    a blank line; any reply outside the 2xx range also closes the
    connection.  After a successful CONNECTVNC all further bytes are
    relayed unmodified to and from the outgoing connection.
    """

    def __init__(self, server="localhost"):
        # Host that outgoing VNC connections are made to.
        self.server = server
        # Results of the most recent validateToken() call.
        self.auth = None
        self.auth_error = None
        self.auth_machine = None
        self.auth_data = None

    def connectionMade(self):
        """Start with an empty request buffer and no relay peer."""
        self.buf = ""
        self.otherConn = None

    def validateToken(self, token):
        """Check a client-supplied token and record the outcome.

        A token is a URL-safe base64 string wrapping a pickled dict with
        'data' (a pickled payload) and 'digest' (the HMAC-SHA1 of 'data'
        under the key from getTokenKey()).  The payload must provide
        "expires", "user" and "machine".

        On success self.auth, self.auth_machine and self.auth_data are
        set and self.auth_error is None.  On any failure self.auth is
        None and self.auth_error holds the message for the client.
        """
        # Forget any earlier result first, so that a rejected token can
        # never leave a previously accepted identity in place.
        self.auth = None
        self.auth_machine = None
        self.auth_data = None
        self.auth_error = "Invalid token"
        try:
            # SECURITY(review): the outer envelope is unpickled *before*
            # its HMAC is verified, so cPickle.loads runs on bytes fully
            # controlled by the client.  Unpickling untrusted data can
            # execute arbitrary code; the token format should move to a
            # safe encoding that is authenticated before being parsed.
            envelope = cPickle.loads(base64.urlsafe_b64decode(token))
            mac = hmac.new(getTokenKey(), digestmod=sha)
            mac.update(envelope['data'])
            # NOTE(review): this is not a constant-time comparison.
            if mac.digest() != envelope['digest']:
                return
            data = cPickle.loads(envelope['data'])
            # Pull out every required field before touching self.*, so a
            # payload missing one of them is rejected as a whole.
            expires = data["expires"]
            user = data["user"]
            machine = data["machine"]
        except (TypeError, ValueError, KeyError, IndexError, AttributeError,
                EOFError, ImportError, cPickle.UnpicklingError):
            # Malformed base64, a truncated or corrupt pickle, or a
            # missing field: treat all of them as an invalid token
            # rather than letting the exception kill the connection
            # without a reply.
            sys.stdout.write("%s\n" % (sys.exc_info(),))
            return

        if time.time() < expires:
            self.auth = user
            self.auth_error = None
            self.auth_machine = machine
            self.auth_data = data
        else:
            self.auth_error = "Token has expired; please try logging in again"

    def dataReceived(self, data):
        """Relay data once connected; otherwise buffer and parse a request."""
        if self.otherConn:
            self.otherConn.write(data)
            return

        self.buf = self.buf + data
        if not (('\r\n\r\n' in self.buf) or ('\n\n' in self.buf)
                or ('\r\r' in self.buf)):
            # The blank line ending the header block has not arrived yet.
            return

        # Take the whole buffer as one request.  Anything the client
        # sent after the blank line is discarded along with it.
        request, self.buf = self.buf, ''
        lines = request.splitlines()
        args = lines.pop(0).split()
        if not args:
            self.makeReply(400, "Malformed request")
            return
        command = args.pop(0)

        headers = {}
        for line in lines:
            # Lines without a ": " separator (including the terminating
            # blank line) are ignored.
            parts = line.split(": ", 1)
            if len(parts) == 2:
                headers[parts[0]] = parts[1]

        if command == "AUTHTOKEN":
            self._handleAuthToken(args, headers)
        elif command == "CONNECTVNC":
            self._handleConnectVNC(args, headers)
        else:
            self.makeReply(501, "unknown method "+command)

    def _handleAuthToken(self, args, headers):
        """Handle "AUTHTOKEN <user>": reply 200 or 401."""
        if not args:
            self.makeReply(400, "Missing user name")
            return
        user = args[0]
        # FIXME(review): placeholder check carried over unchanged -- any
        # client sending "Auth-token: 1" is accepted as the user it
        # names.  CONNECTVNC validates its own token and resets
        # self.auth, so this does not grant access to a VM by itself.
        if headers.get("Auth-token") == "1":
            self.auth = user
            self.makeReply(200, "Authentication successful")
        else:
            self.makeReply(401)

    def _handleConnectVNC(self, args, headers):
        """Handle "CONNECTVNC <vmname>": validate the token and connect.

        Replies 401 for a missing, invalid or expired token or a token
        issued for another machine, and 404 when no VNC port is found or
        the outgoing connection fails.  The 200 reply is sent by
        VNCAuthOutgoing once the outgoing connection is established.
        """
        if not args:
            self.makeReply(400, "Missing VM name")
            return
        vmname = args[0]

        if "Auth-token" not in headers:
            self.makeReply(401, "Login first")
            return

        self.validateToken(headers["Auth-token"])
        if self.auth is None:
            self.makeReply(401, self.auth_error or "Invalid token")
            return

        port = getPort(vmname, self.auth_data)
        if port is None:
            # getPort returns None when the token names another machine.
            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
            return
        if port == 0:
            # getPort returns 0 when no port was found for the VM.
            self.makeReply(404, "Unable to find VNC for VM "+vmname)
            return

        d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
        d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))

    def connectionLost(self, reason):
        """Close the outgoing connection, if any, when the client leaves."""
        if self.otherConn:
            self.otherConn.transport.loseConnection()

    def authorize(self, code, server, port, user):
        """Log the request and allow it unconditionally (always returns 1).

        Not called by the request handling in this class.
        """
        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
        return 1

    def connectClass(self, host, port, klass, *args):
        """Open a TCP connection to host:port speaking ``klass(*args)``.

        Returns the Deferred produced by ClientCreator.connectTCP.
        """
        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host, port)

    def makeReply(self, reply, message=""):
        """Send a status line to the client; close unless ``reply`` is 2xx."""
        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
        if reply // 100 != 2:
            self.transport.loseConnection()

    def write(self, data):
        """Send ``data`` to the client."""
        self.transport.write(data)

    def log(self, proto, data):
        """Write a timestamped hex dump of ``data`` to stdout.

        The heading shows this connection's peer, a direction marker
        ('<' when ``proto`` is this object, '>' otherwise) and the peer
        of the outgoing connection.  Each following line covers up to 16
        bytes: a hex column, then the same bytes as characters with '.'
        standing in for any byte whose repr() is an escape sequence.
        Requires self.otherConn to be set.
        """
        peer = self.transport.getPeer()
        their_peer = self.otherConn.transport.getPeer()
        if proto == self:
            direction = '<'
        else:
            direction = '>'

        out = sys.stdout
        out.write("%s\t%s:%d %s %s:%d\n" % (time.ctime(),
                                            peer.host, peer.port,
                                            direction,
                                            their_peer.host, their_peer.port))
        while data:
            chunk, data = data[:16], data[16:]
            hex_col = ' '.join(['%02X' % ord(ch) for ch in chunk])
            # Pad short final rows so the character column stays aligned.
            pad = (16 - len(chunk)) * 3 * ' '
            chars = []
            for ch in chunk:
                if len(repr(ch)) > 3:
                    chars.append('.')
                else:
                    chars.append(ch)
            out.write("%s  %s %s \n" % (hex_col, pad, ' '.join(chars)))
        out.write("\n")


class VNCAuthFactory(protocol.Factory):
    """A factory for a VNC auth proxy.

    Constructor accepts one argument, the host name that outgoing VNC
    connections are made to.  It is handed to every VNCAuth protocol
    this factory builds and defaults to "localhost", matching VNCAuth.
    """

    def __init__(self, server="localhost"):
        self.server = server

    def buildProtocol(self, addr):
        """Return a new VNCAuth for an incoming connection.

        ``addr`` (the client's address) is not used.
        """
        return VNCAuth(self.server)

