#!/usr/bin/python

import sys
import cgi
import os
import string
import subprocess
import re
import time
import cPickle
import base64
import sha
import hmac

print 'Content-Type: text/html\n'
sys.stderr = sys.stdout
sys.path.append('/home/ecprice/.local/lib/python2.5/site-packages')

from Cheetah.Template import Template
from sipb_xen_database import *
import random

class MyException(Exception):
    pass

# ... and stolen from xend/uuid.py
def randomUUID():
    """Generate a random UUID."""

    return [ random.randint(0, 255) for _ in range(0, 16) ]

def uuidToString(u):
    return "-".join(["%02x" * 4, "%02x" * 2, "%02x" * 2, "%02x" * 2,
                     "%02x" * 6]) % tuple(u)

def maxMemory(user):
    return 256

def maxDisk(user):
    return 10.0

def haveAccess(user, machine):
    return True

def error(op, user, fields, err):
    d = dict(op=op, user=user, errorMessage=str(err))
    print Template(file='error.tmpl', searchList=d);

def validMachineName(name):
    """Check that name is valid for a machine name"""
    if not name:
        return False
    charset = string.ascii_letters + string.digits + '-_'
    if name[0] in '-_' or len(name) > 22:
        return False
    return all(x in charset for x in name)

def kinit(username = 'tabbott/extra', keytab = '/etc/tabbott.keytab'):
    """Kinit with a given username and keytab"""

    p = subprocess.Popen(['kinit', "-k", "-t", keytab, username])
    e = p.wait()
    if e:
        raise MyException("Error %s in kinit" % e)

def checkKinit():
    """If we lack tickets, kinit."""
    p = subprocess.Popen(['klist', '-s'])
    if p.wait():
        kinit()

def remctl(*args, **kws):
    """Perform a remctl and return the output.

    kinits if necessary, and outputs errors to stderr.
    """
    checkKinit()
    p = subprocess.Popen(['remctl', 'black-mesa.mit.edu']
                         + list(args),
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    if kws.get('err'):
        return p.stdout.read(), p.stderr.read()
    if p.wait():
        print >> sys.stderr, 'ERROR on remctl ', args
        print >> sys.stderr, p.stderr.read()
    return p.stdout.read()

def makeDisks():
    """Update the lvm partitions to include all disks in the database."""
    remctl('web', 'lvcreate')

def bootMachine(machine, cdtype):
    """Boot a machine with a given boot CD.

    If cdtype is None, give no boot cd.  Otherwise, it is the string
    id of the CD (e.g. 'gutsy_i386')
    """
    if cdtype is not None:
        remctl('web', 'vmboot', machine.name,
               cdtype)
    else:
        remctl('web', 'vmboot', machine.name)

def registerMachine(machine):
    """Register a machine to be controlled by the web interface"""
    remctl('web', 'register', machine.name)

def parseStatus(s):
    """Parse a status string into nested tuples of strings.

    s = output of xm list --long <machine_name>
    """
    values = re.split('([()])', s)
    stack = [[]]
    for v in values[2:-2]: #remove initial and final '()'
        if not v:
            continue
        v = v.strip()
        if v == '(':
            stack.append([])
        elif v == ')':
            stack[-2].append(stack[-1])
            stack.pop()
        else:
            if not v:
                continue
            stack[-1].extend(v.split())
    return stack[-1]

def statusInfo(machine):
    value_string, err_string = remctl('list-long', machine.name, err=True)
    if 'Unknown command' in err_string:
        raise MyException("ERROR in remctl list-long %s is not registered" % (machine.name,))
    elif 'does not exist' in err_string:
        return None
    elif err_string:
        raise MyException("ERROR in remctl list-long %s:  %s" % (machine.name, err_string))
    status = parseStatus(value_string)
    return status

def hasVnc(status):
    if status is None:
        return False
    for l in status:
        if l[0] == 'device' and l[1][0] == 'vfb':
            d = dict(l[1][1:])
            return 'location' in d
    return False

def createVm(user, name, memory, disk, is_hvm, cdrom):
    # put stuff in the table
    transaction = ctx.current.create_transaction()
    try:
        res = meta.engine.execute('select nextval(\'"machines_machine_id_seq"\')')
        id = res.fetchone()[0]
        machine = Machine()
        machine.machine_id = id
        machine.name = name
        machine.memory = memory
        machine.owner = user.username
        machine.contact = user.email
        machine.uuid = uuidToString(randomUUID())
        machine.boot_off_cd = True
        machine_type = Type.get_by(hvm=is_hvm)
        machine.type_id = machine_type.type_id
        ctx.current.save(machine)
        disk = Disk(machine.machine_id, 
                    'hda', disk)
        open = NIC.select_by(machine_id=None)
        if not open: #No IPs left!
            return "No IP addresses left!  Contact sipb-xen-dev@mit.edu"
        nic = open[0]
        nic.machine_id = machine.machine_id
        nic.hostname = name
        ctx.current.save(nic)    
        ctx.current.save(disk)
        transaction.commit()
    except:
        transaction.rollback()
        raise
    makeDisks()
    registerMachine(machine)
    # tell it to boot with cdrom
    bootMachine(machine, cdrom)

    return machine

def create(user, fields):
    name = fields.getfirst('name')
    if not validMachineName(name):
        raise MyException("Invalid name '%s'" % name)
    name = name.lower()

    if Machine.get_by(name=name):
        raise MyException("A machine named '%s' already exists" % name)
    
    memory = fields.getfirst('memory')
    try:
        memory = int(memory)
        if memory <= 0:
            raise ValueError
    except ValueError:
        raise MyException("Invalid memory amount")
    if memory > maxMemory(user):
        raise MyException("Too much memory requested")
    
    disk = fields.getfirst('disk')
    try:
        disk = float(disk)
        disk = int(disk * 1024)
        if disk <= 0:
            raise ValueError
    except ValueError:
        raise MyException("Invalid disk amount")
    if disk > maxDisk(user):
        raise MyException("Too much disk requested")
    
    vm_type = fields.getfirst('vmtype')
    if vm_type not in ('hvm', 'paravm'):
        raise MyException("Invalid vm type '%s'"  % vm_type)    
    is_hvm = (vm_type == 'hvm')

    cdrom = fields.getfirst('cdrom')
    if cdrom is not None and not CDROM.get(cdrom):
        raise MyException("Invalid cdrom type '%s'" % cdrom)    
    
    machine = createVm(user, name, memory, disk, is_hvm, cdrom)
    if isinstance(machine, basestring):
        raise MyException(machine)
    d = dict(user=user,
             machine=machine)
    print Template(file='create.tmpl',
                   searchList=d);

def listVms(user, fields):
    machines = Machine.select()
    status = statusInfo(machines)
    has_vnc = {}
    for m in machines:
        on[m.name] = status[m.name] is not None
        has_vnc[m.name] = hasVnc(status[m.name])
    d = dict(user=user,
             maxmem=maxMemory(user),
             maxdisk=maxDisk(user),
             machines=machines,
             status=status,
             has_vnc=has_vnc,
             cdroms=CDROM.select())
    print Template(file='list.tmpl', searchList=d)

def testMachineId(user, machineId, exists=True):
    if machineId is None:
        raise MyException("No machine ID specified")
    try:
        machineId = int(machineId)
    except ValueError:
        raise MyException("Invalid machine ID '%s'" % machineId)
    machine = Machine.get(machineId)
    if exists and machine is None:
        raise MyException("No such machine ID '%s'" % machineId)
    if not haveAccess(user, machine):
        raise MyException("No access to machine ID '%s'" % machineId)
    return machine

def vnc(user, fields):
    """VNC applet page.

    Note that due to same-domain restrictions, the applet connects to
    the webserver, which needs to forward those requests to the xen
    server.  The Xen server runs another proxy that (1) authenticates
    and (2) finds the correct port for the VM.

    You might want iptables like:

    -t nat -A PREROUTING -s ! 18.181.0.60 -i eth1 -p tcp -m tcp --dport 10003 -j DNAT --to-destination 18.181.0.60:10003 
    -t nat -A POSTROUTING -d 18.181.0.60 -o eth1 -p tcp -m tcp --dport 10003 -j SNAT --to-source 18.187.7.142 
    -A FORWARD -d 18.181.0.60 -i eth1 -o eth1 -p tcp -m tcp --dport 10003 -j ACCEPT
    """
    machine = testMachineId(user, fields.getfirst('machine_id'))
    #XXX fix
    
    TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"

    data = {}
    data["user"] = user
    data["machine"]=machine
    data["expires"]=time.time()+(5*60)
    pickledData = cPickle.dumps(data)
    m = hmac.new(TOKEN_KEY, digestmod=sha)
    m.update(pickledData)
    token = {'data': pickledData, 'digest': m.digest()}
    token = cPickle.dumps(token)
    token = base64.urlsafe_b64encode(token)
    
    d = dict(user=user,
             machine=machine,
             hostname=os.environ.get('SERVER_NAME', 'localhost'),
             authtoken=token)
    print Template(file='vnc.tmpl',
                   searchList=d)

def info(user, fields):
    machine = testMachineId(user, fields.getfirst('machine_id'))
    d = dict(user=user,
             machine=machine)
    print Template(file='info.tmpl',
                   searchList=d)

mapping = dict(list=listVms,
               vnc=vnc,
               info=info,
               create=create)

if __name__ == '__main__':
    fields = cgi.FieldStorage()
    class C:
        username = "moo"
        email = 'moo@cow.com'
    u = C()
    connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')
    operation = os.environ.get('PATH_INFO', '')
    if not operation:
        pass
        #XXX do redirect

    if operation.startswith('/'):
        operation = operation[1:]
    if not operation:
        operation = 'list'
    
    fun = mapping.get(operation, 
                      lambda u, e:
                          error(operation, u, e,
                                "Invalid operation '%'" % operation))
    try:
        fun(u, fields)
    except MyException, err:
        error(operation, u, fields, err)
