#!/usr/bin/python
from twisted.internet import reactor, ssl, protocol, error
from OpenSSL import SSL
import base64, pickle
import getopt, sys, os, time

verbose = False

def usage():
    print """%s [-v] [-l [HOST:]PORT] {-a AUTHTOKEN|VMNAME}
 -l, --listen [HOST:]PORT  port (and optionally host) to listen on for
                           connections (default is 127.0.0.1 and a randomly
                           chosen port). Use an empty HOST to listen on all
                           interfaces (INSECURE!)
 -a, --authtoken AUTHTOKEN Authentication token for connecting to the VNC server
 VMNAME                    VM name to connect to (automatically fetches an
                           authentication token using remctl)
 -v                        verbose status messages""" % (sys.argv[0])

class ClientContextFactory(ssl.ClientContextFactory):

    def _verify(self, connection, x509, errnum, errdepth, ok):
        if verbose:
            print '_verify (ok=%d):' % ok
            print '  subject:', x509.get_subject()
            print '  issuer:', x509.get_issuer()
            print '  errnum %s, errdepth %d' % (errnum, errdepth)
        if errnum == 10:
            print 'The VNC server certificate has expired. Please contact xvm@mit.edu.'
        return ok

    def getContext(self):
        ctx = ssl.ClientContextFactory.getContext(self)

        certFile = '/mit/xvm/vnc/servers.cert'
        if verbose: print "Loading certificates from %s" % certFile
        ctx.load_verify_locations(certFile)
        ctx.set_verify(SSL.VERIFY_PEER|SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                       self._verify)

        return ctx

class Proxy(protocol.Protocol):
    """Protocol that relays every byte it receives to a paired peer.

    The peer is another Proxy instance attached with setPeer(). When this
    side's connection is lost, the peer's connection is closed as well.
    """
    peer = None

    def setPeer(self, peer):
        """Attach the protocol that received data should be forwarded to."""
        self.peer = peer

    def connectionLost(self, reason):
        """Tear down the peer's connection too, then drop the reference."""
        other = self.peer
        if other is None:
            return
        other.transport.loseConnection()
        self.peer = None

    def dataReceived(self, data):
        """Forward `data` unchanged to the peer's transport."""
        self.peer.transport.write(data)

class ProxyClient(Proxy):
    ready = False

    def connectionMade(self):
        self.peer.setPeer(self)
        data = "CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n" % (self.factory.machine, self.factory.authtoken)
        self.transport.write(data)
        if verbose: print "ProxyClient: connection made"
    def dataReceived(self, data):
        if not self.ready:
            if verbose: print 'ProxyClient: received data "%s"' % data
            if data.startswith("VNCProxy/1.0 200 "):
                self.ready = True
                if "\n" in data:
                    self.peer.transport.write(data[data.find("\n")+3:])
                self.peer.transport.resumeProducing() # Allow reading
            else:
                print "Failed to connect: %s" % data
                self.transport.loseConnection()
        else:
            self.peer.transport.write(data)

class ProxyClientFactory(protocol.ClientFactory):
    protocol = ProxyClient
    
    def __init__(self, authtoken, machine):
        self.authtoken = authtoken
        self.machine = machine

    def setServer(self, server):
        self.server = server

    def buildProtocol(self, *args, **kw):
        prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
        prot.setPeer(self.server)
        return prot

    def clientConnectionFailed(self, connector, reason):
        self.server.transport.loseConnection()


class ProxyServer(Proxy):
    clientProtocolFactory = ProxyClientFactory
    authtoken = None
    machine = None

    def connectionMade(self):
        # Don't read anything from the connecting client until we have
        # somewhere to send it to.
        self.transport.pauseProducing()
        
        if verbose: print "ProxyServer: connection made"

        client = self.clientProtocolFactory(self.factory.authtoken, self.factory.machine)
        client.setServer(self)

        reactor.connectSSL(self.factory.host, self.factory.port, client, ClientContextFactory())
        

class ProxyFactory(protocol.Factory):
    """Listening factory producing a ProxyServer per viewer connection.

    host, port -- address of the VNC proxy server to connect to over SSL
    authtoken  -- token sent in the CONNECTVNC request
    machine    -- machine name sent in the CONNECTVNC request
    """
    protocol = ProxyServer

    def __init__(self, host, port, authtoken, machine):
        (self.host, self.port,
         self.authtoken, self.machine) = (host, port, authtoken, machine)

def main():
    global verbose
    try:
        opts, args = getopt.gnu_getopt(sys.argv[1:], "hl:a:v",
                                       ["help", "listen=", "authtoken="])
    except getopt.GetoptError, err:
        print str(err) # will print something like "option -a not recognized"
        usage()
        sys.exit(2)
    listen = ["127.0.0.1", None]
    authtoken = None
    for o, a in opts:
        if o == "-v":
            verbose = True
        elif o in ("-h", "--help"):
            usage()
            sys.exit()
        elif o in ("-l", "--listen"):
            if ":" in a:
                listen = a.split(":", 2)
                listen[1] = int(listen[1])
            else:
                listen[1] = int(a)
        elif o in ("-a", "--authtoken"):
            authtoken = a
        else:
            assert False, "unhandled option"

    # Get authentication token
    if authtoken is None:
        # User didn't give us an authentication token, so we need to get one
        if len(args) != 1:
            print "VMNAME not given or too many arguments"
            usage()
            sys.exit(2)
        from subprocess import PIPE, Popen
        try:
            p = Popen(["remctl", "remote", "control", args[0], "vnctoken"],
                      stdout=PIPE)
        except OSError:
            if verbose: print "remctl not found in path. Trying remctl locker."
            p = Popen(["athrun", "remctl", "remctl",
                       "remote", "control", args[0], "vnctoken"],
                      stdout=PIPE)
        authtoken = p.communicate()[0]
        if p.returncode != 0:
            print "Unable to get authentication token"
            sys.exit(1)
        if verbose: print 'Got authentication token "%s" for VM %s' % \
                          (authtoken, args[0])

    # Unpack authentication token
    try:
        token_outer = base64.urlsafe_b64decode(authtoken)
        token_outer = pickle.loads(token_outer)
        token_inner = pickle.loads(token_outer["data"])
        machine = token_inner["machine"]
        connect_host = token_inner["connect_host"]
        connect_port = token_inner["connect_port"]
        token_expires = token_inner["expires"]
        if verbose: print "Unpacked authentication token:\n%s" % \
                          repr(token_inner)
    except:
        print "Invalid authentication token"
        sys.exit(1)
    
    if verbose: print "Will connect to %s:%s" % (connect_host, connect_port) 
    if listen[1] is None:
        listen[1] = 5900
        ready = False
        while not ready and listen[1] < 6000:
            try:
                reactor.listenTCP(listen[1], ProxyFactory(connect_host, connect_port, authtoken, machine), interface=listen[0])
                ready = True
            except error.CannotListenError:
                listen[1] += 1
    else:
        reactor.listenTCP(listen[1], ProxyFactory(connect_host, connect_port, authtoken, machine))
    
    print "Ready to connect. Connect to %s:%s (display %d) now with your VNC client. The password is 'moocow'." % (listen[0], listen[1], listen[1]-5900)
    print "You must connect before your authentication token expires at %s." % \
          (time.ctime(token_expires))
    
    reactor.run()

# Run the proxy only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
