"""
Wrapper for sipb-xen VNC proxying
"""

# twisted imports
from twisted.internet import reactor, protocol, defer
from twisted.python import log

# python imports
import struct
import string
import cPickle
# Python 2.5:
#import hashlib
import sha
import hmac
import base64
import socket
import time
import get_port

# Shared secret used as the HMAC key (with the `sha` digest module) when
# VNCAuth.validateToken verifies client auth tokens; whatever issues the
# tokens must sign them with this same key.
# NOTE(review): a secret committed to source is readable by anyone with repo
# access -- consider loading it from a protected config file instead.
TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"

def getPort(name, auth):
    """Return the VNC TCP port for VM *name*, or 0 if it cannot be found.

    get_port.findPort is expected to return a ':'-separated string whose
    second field is the port number (presumably "host:port" -- TODO confirm),
    or None when the VM is unknown.  *auth* is accepted for API
    compatibility but is not currently consulted.
    """
    found = get_port.findPort(name)
    if found is not None:
        fields = found.split(':')
        return int(fields[1])
    # Callers must treat 0 as "no such VM".
    return 0

class VNCAuthOutgoing(protocol.Protocol):
    """Outbound leg of a proxied VNC session (proxy -> VM's VNC port).

    *socks* is the client-facing VNCAuth protocol that requested this
    connection.  Bytes received here are written to it, bytes passed to
    write() go to the VM, and losing this connection closes the client's.
    """

    def __init__(self, socks):
        self.socks = socks

    def connectionMade(self):
        # Tell the client the proxy is up, then register ourselves so the
        # client side relays its bytes to us instead of parsing them.
        self.socks.makeReply(200)
        self.socks.otherConn = self

    def connectionLost(self, reason):
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        # VM -> client.
        self.socks.write(data)

    def write(self, data):
        # Client -> VM.
        #self.socks.log(self,data)
        self.transport.write(data)


class VNCAuth(protocol.Protocol):
    """Client-facing side of the VNC proxy.

    Until a VNC connection is established, incoming bytes are buffered and
    parsed as a text request ending in a blank line: a command line
    (``AUTHTOKEN <user>`` or ``CONNECTVNC <vmname>``) followed by
    ``Name: value`` header lines.  Every request is answered with a
    ``VNCProxy/1.0 <code> <message>`` line (see makeReply); any non-2xx
    reply also closes the connection.  After a successful CONNECTVNC, a
    VNCAuthOutgoing connection is opened to the VM's VNC port and all
    further bytes are relayed unchanged in both directions.
    """

    def __init__(self, logging=None, server="localhost"):
        # logging: file name that log() appends traffic dumps to (a false
        #   value disables dumping).
        # server: host that CONNECTVNC connects to for the VM's VNC port.
        self.logging = logging
        self.server = server
        # Name of the authenticated user, or None.
        self.auth = None
        # Also reset in connectionMade; initialised here so the instance is
        # consistent even before a connection exists.
        self.buf = ""
        self.otherConn = None

    def connectionMade(self):
        self.buf = ""
        # Set by VNCAuthOutgoing once the VM connection is up; while it is
        # None, incoming data is parsed as proxy requests.
        self.otherConn = None

    def validateToken(self, token):
        """Check an ``Auth-token`` value and record whom it authenticates.

        The token is urlsafe-base64 of a pickled mapping with keys 'data'
        (itself a pickled mapping with 'user' and 'expires') and 'digest'
        (an HMAC of 'data' keyed with TOKEN_KEY).  On success self.auth is
        set to the token's user.  On any failure -- bad digest, expired,
        or malformed token -- self.auth is left as None; malformed tokens
        are logged instead of raising out of the protocol.
        """
        # Start from "not authenticated" so a bad token can never ride on an
        # earlier login.
        self.auth = None
        if token == "quentin":
            # SECURITY/FIXME: hard-coded authentication bypass; anyone who
            # sends this literal token is treated as this user.
            self.auth = "quentin@ATHENA.MIT.EDU"
            return
        try:
            # SECURITY: this outer unpickle runs on unauthenticated client
            # input *before* the HMAC is checked, and cPickle.loads on
            # attacker-controlled bytes can execute arbitrary code.  The
            # token envelope should move to a non-pickle encoding.
            envelope = cPickle.loads(base64.urlsafe_b64decode(token))
            signed = envelope['data']
            mac = hmac.new(TOKEN_KEY, digestmod=sha)
            mac.update(signed)
            if not self._digestsEqual(mac.digest(), envelope['digest']):
                return
            # 'data' is only unpickled after its HMAC has verified.
            payload = cPickle.loads(signed)
            if time.time() < payload["expires"]:
                self.auth = payload["user"]
        except Exception:
            log.msg("rejecting malformed auth token")

    def _digestsEqual(self, expected, given):
        """Compare two digest strings in time independent of their contents."""
        if not isinstance(given, str) or len(given) != len(expected):
            return False
        diff = 0
        for a, b in zip(expected, given):
            diff |= ord(a) ^ ord(b)
        return diff == 0

    def dataReceived(self, data):
        """Relay traffic once connected; otherwise buffer and parse requests."""
        if self.otherConn:
            self.otherConn.write(data)
            return
        self.buf = self.buf + data
        # A request ends at the first blank line, whatever the line ending.
        if not (('\r\n\r\n' in self.buf) or ('\n\n' in self.buf)
                or ('\r\r' in self.buf)):
            return
        lines = self.buf.splitlines()
        self.buf = ''
        # splitlines() of a buffer holding a blank line is never empty.
        args = lines.pop(0).split()
        if not args:
            self.makeReply(400, "Empty request")
            return
        command = args.pop(0)
        headers = {}
        for line in lines:
            try:
                (header, value) = line.split(": ", 1)
            except ValueError:
                # Not a "Name: value" line (e.g. the terminating blank line).
                continue
            headers[header] = value

        if command == "AUTHTOKEN":
            self._handleAuthToken(args, headers)
        elif command == "CONNECTVNC":
            self._handleConnectVNC(args, headers)
        else:
            self.makeReply(501, "unknown method "+command)

    def _handleAuthToken(self, args, headers):
        """Handle ``AUTHTOKEN <user>`` using the Auth-token header."""
        if not args:
            self.makeReply(400, "AUTHTOKEN requires a user name")
            return
        # SECURITY/FIXME: placeholder check -- any client presenting the
        # literal token "1" is accepted as whatever user it names.
        if headers.get("Auth-token") == "1":
            self.auth = args[0]
            self.makeReply(200, "Authentication successful")
        else:
            self.makeReply(401)

    def _handleConnectVNC(self, args, headers):
        """Handle ``CONNECTVNC <vmname>``: authenticate, then connect to the VM."""
        if not args:
            self.makeReply(400, "CONNECTVNC requires a VM name")
            return
        vmname = args[0]
        if "Auth-token" not in headers:
            self.makeReply(401, "Login first")
            return
        self.validateToken(headers["Auth-token"])
        if self.auth is None:
            self.makeReply(401, "Invalid token")
            return
        port = getPort(vmname, self.auth)
        # getPort reports "not found" as 0, so test truthiness; the former
        # `is not None` check never failed and sent clients to port 0.
        if not port:
            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
            return
        # The 200 reply is sent by VNCAuthOutgoing.connectionMade on success.
        d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
        d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))

    def connectionLost(self, reason):
        if self.otherConn:
            self.otherConn.transport.loseConnection()

    def authorize(self, code, server, port, user):
        """Log a connection request and allow it (always returns 1).

        Nothing in this class calls it any more; kept for compatibility.
        """
        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
        return 1

    def connectClass(self, host, port, klass, *args):
        """Connect to host:port with a *klass*(*args) protocol; returns a Deferred."""
        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)

    def makeReply(self, reply, message=""):
        """Send a ``VNCProxy/1.0 <code> <message>`` status line.

        Any reply outside the 2xx range also closes the connection.
        """
        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
        if reply // 100 != 2:
            self.transport.loseConnection()

    def write(self, data):
        # VM -> client.
        #self.log(self,data)
        self.transport.write(data)

    def log(self, proto, data):
        """Append a hex/ASCII dump of *data* to the log file, if enabled.

        *proto* is the protocol the data came from: ``self`` is logged as
        '<', anything else as '>'.  Requires self.otherConn to be set.
        """
        if not self.logging:
            return
        peer = self.transport.getPeer()
        their_peer = self.otherConn.transport.getPeer()
        if proto == self:
            direction = '<'
        else:
            direction = '>'
        f = open(self.logging, "a")
        try:
            f.write("%s\t%s:%d %s %s:%d\n" % (time.ctime(),
                                              peer.host, peer.port,
                                              direction,
                                              their_peer.host, their_peer.port))
            # 16 bytes per row: hex column padded to fixed width, then the
            # printable characters ('.' for anything repr() escapes).
            while data:
                chunk, data = data[:16], data[16:]
                f.write(' '.join(['%02X' % ord(c) for c in chunk]) + ' ')
                f.write((16 - len(chunk)) * 3 * ' ')
                for c in chunk:
                    if len(repr(c)) > 3:
                        f.write('.')
                    else:
                        f.write(c)
                f.write('\n')
            f.write('\n')
        finally:
            f.close()


class VNCAuthFactory(protocol.Factory):
    """A factory for a VNC auth proxy.

    Constructor arguments:
      log    -- file name for VNCAuth.log traffic dumps (a false value
                disables dumping); passed to each VNCAuth as ``logging``.
      server -- host that proxied VNC connections are made to; passed to
                each VNCAuth as ``server``.
    """

    def __init__(self, log, server):
        self.logging = log
        self.server = server

    def buildProtocol(self, addr):
        """Return a new VNCAuth protocol for an incoming connection."""
        return VNCAuth(self.logging, self.server)

