#!/usr/bin/python

import sys
import cgi
import os
import string
import subprocess
import re
import time
import cPickle
import base64
import sha
import hmac
import datetime

# Send everything written to stderr (including uncaught tracebacks) to
# stdout -- presumably so errors show up in the CGI response page.
sys.stderr = sys.stdout
# Extra site-packages directory; presumably where Cheetah (imported below)
# lives, so this must run before that import.
sys.path.append('/home/ecprice/.local/lib/python2.5/site-packages')

from Cheetah.Template import Template
from sipb_xen_database import *
import random

class MyException(Exception):
    """Common base class for the exceptions raised by this web interface."""

class InvalidInput(MyException):
    """Raised when user-provided input is invalid, possibly in good faith.

    Examples are requesting a negative amount of memory (which might be a
    typo).  Input that could only be produced by bypassing the form, such as
    an invalid boot CD outside the select box, is a CodeError instead.
    """

class CodeError(MyException):
    """Raised for internal errors and for input that must be in bad faith."""



def helppopup(subj):
    """Build the HTML for a "(?)" link that opens help topic ``subj``.

    The link targets ``help?subject=<subj>&simple=true`` in a new window and
    also calls the page's JavaScript ``helppopup`` function on click.
    ``subj`` is inserted verbatim (no escaping), so callers must pass a
    safe topic name.
    """
    pieces = [
        '<span class="helplink"><a href="help?subject=',
        subj,
        '&amp;simple=true" target="_blank" onclick="return helppopup(\'',
        subj,
        '\')">(?)</a></span>',
    ]
    return ''.join(pieces)


# Names made available to every Cheetah template (each handler passes this
# dict as the last entry of its searchList).
global_dict = {'helppopup': helppopup}


# Adapted from xend/uuid.py.
def randomUUID():
    """Return a new random UUID as a list of 16 ints in the range 0-255.

    Uses the ``random`` module; the bytes are not stamped with RFC 4122
    version/variant bits.
    """
    octets = []
    for _ in range(16):
        octets.append(random.randint(0, 255))
    return octets

def uuidToString(u):
    """Format a 16-byte numeric UUID as hyphen-separated hex (8-4-4-4-12)."""
    group_sizes = (4, 2, 2, 2, 6)
    pattern = "-".join(["%02x" * size for size in group_sizes])
    return pattern % tuple(u)

# Per-user resource quotas.  Memory limits are in MB and disk limits are in
# GB (maxDisk converts the per-disk sizes, stored in MB, to GB).
MAX_MEMORY_TOTAL = 512   # total RAM across a user's running machines
MAX_MEMORY_SINGLE = 256  # RAM cap for any one machine
MIN_MEMORY_SINGLE = 16   # smallest RAM validMemory accepts
MAX_DISK_TOTAL = 50      # total disk across all of a user's machines
MAX_DISK_SINGLE = 50     # disk cap for any one machine
MIN_DISK_SINGLE = 0.1    # smallest disk validDisk accepts
MAX_VMS_TOTAL = 10       # machines a user may own
MAX_VMS_ACTIVE = 4       # machines a user may have running at once

def getMachinesByOwner(owner):
    """Return the Machine objects whose ``owner`` equals the given username."""
    owned = Machine.select_by(owner=owner)
    return owned

def maxMemory(user, machine=None, on=None):
    """Return the most RAM (MB) that may be given to one of ``user``'s machines.

    The result is the smaller of the per-machine cap and what is left of the
    user's total quota after counting their running machines.  If
    ``machine`` is given, its own memory is excluded, so the result is the
    most that machine may have.  Otherwise it is the room left for a new
    machine.

    ``on`` maps each of the user's machines to a value that is truthy when
    the machine is running (as returned by getUptimes: an uptime string, or
    None).  If omitted it is fetched with getUptimes().
    XXX make this global?
    """
    owned = getMachinesByOwner(user.username)
    if on is None:
        on = getUptimes(owned)
    in_use = 0
    for m in owned:
        if on[m] and m != machine:
            in_use += m.memory
    return min(MAX_MEMORY_SINGLE, MAX_MEMORY_TOTAL - in_use)

def maxDisk(user, machine=None):
    """Return the most disk (GB) that may be given to one of ``user``'s machines.

    Disk sizes are stored in MB and summed over all of the user's machines
    (running or not), except ``machine`` when it is given.  The result is the
    smaller of the per-machine cap and the remaining total quota.
    """
    used_mb = 0
    for m in getMachinesByOwner(user.username):
        if m != machine:
            used_mb += sum([d.size for d in m.disks])
    return min(MAX_DISK_SINGLE, MAX_DISK_TOTAL - used_mb / 1024.)

def canAddVm(user, on=None):
    """Return whether ``user`` is below both the owned and running VM limits.

    ``on`` has the same meaning as in maxMemory(); it is fetched with
    getUptimes() when omitted.
    """
    owned = getMachinesByOwner(user.username)
    if on is None:
        on = getUptimes(owned)
    running = len([m for m in owned if on[m]])
    return len(owned) < MAX_VMS_TOTAL and running < MAX_VMS_ACTIVE

def haveAccess(user, machine):
    """Return whether ``user`` may control ``machine``.

    The user 'moo' may access every machine; anyone else only the machines
    they own.
    """
    return user.username == 'moo' or machine.owner == user.username

def error(op, user, fields, err):
    """Render error.tmpl for a failed request.

    ``op`` is the requested operation and ``err`` the exception (or message)
    whose text is shown to the user.  ``fields`` is currently unused.
    """
    context = {'op': op, 'user': user, 'errorMessage': str(err)}
    print(Template(file='error.tmpl', searchList=[context, global_dict]))

def validMachineName(name):
    """Return whether ``name`` is acceptable as a machine name.

    A valid name is non-empty, at most 22 characters long, consists only of
    ASCII letters, digits, '-' and '_', and does not start with '-' or '_'.
    """
    allowed = set(string.ascii_letters + string.digits + '-_')
    if not name or len(name) > 22 or name[0] in '-_':
        return False
    return all(ch in allowed for ch in name)

def kinit(username = 'tabbott/extra', keytab = '/etc/tabbott.keytab'):
    """Obtain Kerberos tickets for ``username`` from ``keytab``.

    Runs ``kinit -k -t <keytab> <username>``.

    Raises CodeError, including kinit's exit status and stderr output, if
    kinit fails.
    """
    p = subprocess.Popen(['kinit', "-k", "-t", keytab, username],
                         stderr=subprocess.PIPE)
    # communicate() drains stderr while waiting.  Calling wait() and only
    # then reading a PIPE can deadlock if the child fills the pipe buffer.
    _, err_output = p.communicate()
    if p.returncode:
        raise CodeError("Error %s in kinit: %s" % (p.returncode, err_output))

def checkKinit():
    """Run kinit() when ``klist -s`` exits non-zero (no valid tickets)."""
    have_tickets = subprocess.call(['klist', '-s']) == 0
    if not have_tickets:
        kinit()

def remctl(*args, **kws):
    """Run ``remctl black-mesa.mit.edu <args...>`` and return its output.

    Kerberos tickets are obtained first if needed (checkKinit()).

    With the keyword ``err=True``, the exit status is not checked and a
    ``(stdout, stderr)`` pair is returned.  Otherwise stdout is returned,
    and CodeError (including stderr) is raised on a non-zero exit status.
    """
    checkKinit()
    p = subprocess.Popen(['remctl', 'black-mesa.mit.edu']
                         + list(args),
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    # communicate() reads both pipes while waiting.  wait() followed by
    # read() can deadlock once the output exceeds the OS pipe buffer
    # (e.g. a long listvms or list-long result).
    out, err_output = p.communicate()
    if kws.get('err'):
        return out, err_output
    if p.returncode:
        raise CodeError('ERROR on remctl %s: %s' %
                          (args, err_output))
    return out

def makeDisks():
    """Ask the host (remctl web lvcreate) to bring the LVM volumes in line
    with the disks recorded in the database."""
    command = ('web', 'lvcreate')
    remctl(*command)

def bootMachine(machine, cdtype):
    """Boot ``machine``, optionally from a CD.

    ``cdtype`` is the string id of the boot CD (e.g. 'gutsy_i386'), or None
    to boot without one.
    """
    args = ['web', 'vmboot', machine.name]
    if cdtype is not None:
        args.append(cdtype)
    remctl(*args)

def registerMachine(machine):
    """Register ``machine`` with the host so the web interface can control it."""
    name = machine.name
    remctl('web', 'register', name)

def unregisterMachine(machine):
    """Remove ``machine``'s registration so the web interface no longer controls it."""
    name = machine.name
    remctl('web', 'unregister', name)

def parseStatus(s):
    """Parse ``xm list --long <machine_name>`` output into nested lists.

    Every parenthesised group becomes a list holding its whitespace-separated
    words and any nested groups.  A group with a single word is padded with
    '' so it can serve as a key/value pair.  The outermost ``( ... )`` is
    stripped and the contents of that top-level group are returned, e.g.
    ``"(domain (name foo))"`` -> ``['domain', ['name', 'foo']]``.
    """
    tokens = re.split('([()])', s)
    # Skip the leading '' and '(' and the trailing ')' and '' that come from
    # splitting around the outermost pair of parentheses.
    frames = [[]]
    for raw in tokens[2:-2]:
        tok = raw.strip()
        if tok == '(':
            frames.append([])
        elif tok == ')':
            finished = frames.pop()
            if len(finished) == 1:
                finished.append('')
            frames[-1].append(finished)
        elif tok:
            frames[-1].extend(tok.split())
    return frames[-1]

def parseUptimes(value_string):
    """Parse ``remctl web listvms`` output into a dict of VM name -> uptime.

    The first line is a header and is skipped.  Each later line is read as
    ``<name> <id> <uptime words...>``, and the uptime words are re-joined
    with single spaces.  Blank or truncated lines (fewer than two fields)
    are ignored instead of raising ValueError.
    """
    uptimes = {}
    for line in value_string.splitlines()[1:]:
        parts = line.split()
        if len(parts) < 2:
            continue
        # parts[1] is the id column, which is not needed here.
        uptimes[parts[0]] = ' '.join(parts[2:])
    return uptimes

def getUptimes(machines):
    """Return a dict mapping each Machine in ``machines`` to its uptime.

    The value is the uptime string reported by ``listvms``, or None when the
    machine's name is not listed (callers treat a falsy value as "off").
    """
    uptimes = parseUptimes(remctl('web', 'listvms'))
    return dict([(m, uptimes.get(m.name)) for m in machines])

def statusInfo(machine):
    """Return ``machine``'s parsed ``xm list --long`` status, or None.

    None is returned when the host reports that the machine does not exist.
    CodeError is raised if the machine is not registered for remctl, or if
    remctl writes any other error text.
    """
    out, err = remctl('list-long', machine.name, err=True)
    if 'Unknown command' in err:
        raise CodeError("ERROR in remctl list-long %s is not registered" % (machine.name,))
    if 'does not exist' in err:
        return None
    if err:
        raise CodeError("ERROR in remctl list-long %s:  %s" % (machine.name, err))
    return parseStatus(out)

def hasVnc(status):
    """Return whether a parsed status list (from statusInfo) offers VNC.

    Only the first ``(device (vfb ...))`` entry is examined: VNC is
    available when that entry has a ``location`` option.  A None status
    (machine not running) means no VNC.
    """
    if status is None:
        return False
    for entry in status:
        if entry[0] != 'device' or entry[1][0] != 'vfb':
            continue
        options = dict(entry[1][1:])
        return 'location' in options
    return False

def createVm(user, name, memory, disk, is_hvm, cdrom):
    """Create a VM in the database, then register it, create its disk and boot it.

    Arguments:
      user   -- owner; its username and email are stored on the machine.
      name   -- full machine name (create() validates and prefixes it).
      memory -- RAM in MB.
      disk   -- size in MB of the machine's single 'hda' disk.
      is_hvm -- True for an HVM machine, False for ParaVM.
      cdrom  -- id of the CD to boot from, or None for no CD.

    Quotas are re-checked inside the database transaction.  Any failure
    before the commit rolls the transaction back and re-raises:
    InvalidInput when over quota, CodeError when no free NIC/IP or no
    matching machine type exists.  The transaction is committed before
    registerMachine/makeDisks/bootMachine run, so failures in those steps
    leave the database rows in place.

    Returns the new Machine.
    """
    transaction = ctx.current.create_transaction()
    try:
        if memory > maxMemory(user):
            raise InvalidInput("Too much memory requested")
        # maxDisk() is in GB while disk is in MB.
        if disk > maxDisk(user) * 1024:
            raise InvalidInput("Too much disk requested")
        if not canAddVm(user):
            raise InvalidInput("Too many VMs requested")
        res = meta.engine.execute('select nextval(\'"machines_machine_id_seq"\')')
        machine_id = res.fetchone()[0]
        machine = Machine()
        machine.machine_id = machine_id
        machine.name = name
        machine.memory = memory
        machine.owner = user.username
        machine.contact = user.email
        machine.uuid = uuidToString(randomUUID())
        machine.boot_off_cd = True
        machine_type = Type.get_by(hvm=is_hvm)
        if machine_type is None:
            raise CodeError("No machine type configured for hvm=%s" % is_hvm)
        machine.type_id = machine_type.type_id
        ctx.current.save(machine)
        disk_row = Disk(machine.machine_id,
                        'hda', disk)
        # A NIC with no machine is a free IP address to hand out.
        free_nics = NIC.select_by(machine_id=None)
        if not free_nics:
            raise CodeError("No IP addresses left!  Contact sipb-xen-dev@mit.edu")
        nic = free_nics[0]
        nic.machine_id = machine.machine_id
        nic.hostname = name
        ctx.current.save(nic)
        ctx.current.save(disk_row)
        transaction.commit()
    except:
        # Deliberately catch everything: roll back, then re-raise unchanged.
        transaction.rollback()
        raise
    registerMachine(machine)
    makeDisks()
    # Boot from the requested CD (or none).
    bootMachine(machine, cdrom)

    return machine

def validMemory(user, memory, machine=None):
    """Parse and validate a requested memory size in MB.

    ``memory`` is the raw form value.  It may be None when the field is
    missing.  ``machine`` is passed through to maxMemory() so an existing
    machine's own usage is not counted against it.

    Returns the memory as an int.  Raises InvalidInput if the value is not
    an integer, is below MIN_MEMORY_SINGLE, or exceeds the user's quota.
    """
    try:
        memory = int(memory)
        if memory < MIN_MEMORY_SINGLE:
            raise ValueError
    except (ValueError, TypeError):
        # TypeError: int(None) for a missing field, which would otherwise
        # escape as an uncaught traceback.
        raise InvalidInput("Invalid memory amount; must be at least %s MB" %
                          MIN_MEMORY_SINGLE)
    if memory > maxMemory(user, machine):
        raise InvalidInput("Too much memory requested")
    return memory

def validDisk(user, disk, machine=None):
    """Parse and validate a requested disk size given in GB.

    ``disk`` is the raw form value.  It may be None when the field is
    missing.  ``machine`` is passed through to maxDisk() so an existing
    machine's own disks are not counted against it.

    Returns the size converted to an int number of MB.  Raises InvalidInput
    if the value is not a number, exceeds the user's quota, or is below
    MIN_DISK_SINGLE GB.
    """
    invalid_message = ("Invalid disk amount; minimum is %s GB" %
                       MIN_DISK_SINGLE)
    try:
        disk = float(disk)
    except (ValueError, TypeError):
        # TypeError: float(None) for a missing field.
        raise InvalidInput(invalid_message)
    if disk > maxDisk(user, machine):
        raise InvalidInput("Too much disk requested")
    try:
        disk = int(disk * 1024)
    except (ValueError, OverflowError):
        # NaN or -inf cannot be converted to an int.
        raise InvalidInput(invalid_message)
    if disk < MIN_DISK_SINGLE * 1024:
        raise InvalidInput(invalid_message)
    return disk

def create(user, fields):
    """Handler for create requests.

    Validates the submitted name, memory, disk, VM type and boot CD.
    Creates the VM as ``<username>_<lowercased name>`` and renders
    create.tmpl.  Raises InvalidInput for bad user input and CodeError for
    values that could not come from the form.
    """
    requested = fields.getfirst('name')
    if not validMachineName(requested):
        raise InvalidInput("Invalid name '%s'" % requested)
    name = user.username + '_' + requested.lower()

    if Machine.get_by(name=name):
        raise InvalidInput("A machine named '%s' already exists" % name)

    memory = validMemory(user, fields.getfirst('memory'))
    disk = validDisk(user, fields.getfirst('disk'))

    vm_type = fields.getfirst('vmtype')
    if vm_type not in ('hvm', 'paravm'):
        raise CodeError("Invalid vm type '%s'" % vm_type)
    is_hvm = (vm_type == 'hvm')

    cdrom = fields.getfirst('cdrom')
    if cdrom is not None and not CDROM.get(cdrom):
        raise CodeError("Invalid cdrom type '%s'" % cdrom)

    machine = createVm(user, name, memory, disk, is_hvm, cdrom)
    context = dict(user=user, machine=machine)
    print(Template(file='create.tmpl', searchList=[context, global_dict]))

def listVms(user, fields):
    """Handler for list requests: render list.tmpl with the user's machines.

    Shows every machine the user may access.  For each one, the template
    gets its uptime and a VNC indicator: 'Off' when not running, True for
    running HVM machines, or a "ParaVM" note with a help link.  It also gets
    the user's remaining memory/disk quota and whether another VM can be
    created.
    """
    machines = [m for m in Machine.select() if haveAccess(user, m)]
    # Each machine maps to its uptime string, or None when listvms does not
    # list it.  The truthiness of the value doubles as the on/off flag.
    uptimes = getUptimes(machines)
    has_vnc = {}
    for m in machines:
        if not uptimes[m]:
            has_vnc[m] = 'Off'
        elif m.type.hvm:
            has_vnc[m] = True
        else:
            has_vnc[m] = "ParaVM"+helppopup("paravm_console")
    max_mem = maxMemory(user, on=uptimes)
    max_disk = maxDisk(user)
    d = dict(user=user,
             can_add_vm=canAddVm(user, on=uptimes),
             max_mem=max_mem,
             max_disk=max_disk,
             default_mem=max_mem,
             default_disk=min(4.0, max_disk),
             machines=machines,
             has_vnc=has_vnc,
             uptimes=uptimes,
             cdroms=CDROM.select())
    print(Template(file='list.tmpl', searchList=[d, global_dict]))

def testMachineId(user, machineId, exists=True):
    """Parse ``machineId``, look up the machine and check the user's access.

    Returns the Machine, or None when it does not exist and ``exists`` is
    False.  Raises CodeError when the id is missing or not an integer, when
    the machine does not exist (if ``exists``), or when the user lacks
    access to it.
    """
    if machineId is None:
        raise CodeError("No machine ID specified")
    try:
        machineId = int(machineId)
    except ValueError:
        raise CodeError("Invalid machine ID '%s'" % machineId)
    machine = Machine.get(machineId)
    if machine is None:
        if exists:
            raise CodeError("No such machine ID '%s'" % machineId)
    elif not haveAccess(user, machine):
        raise CodeError("No access to machine ID '%s'" % machineId)
    return machine

def vnc(user, fields):
    """Render the VNC applet page for a machine.

    Because of same-domain restrictions, the applet connects back to this
    webserver, which must forward those connections to the Xen server.  The
    Xen server runs another proxy that (1) authenticates the request and
    (2) finds the correct port for the VM.

    Forwarding can be set up with iptables rules like:

    -t nat -A PREROUTING -s ! 18.181.0.60 -i eth1 -p tcp -m tcp --dport 10003 -j DNAT --to-destination 18.181.0.60:10003 
    -t nat -A POSTROUTING -d 18.181.0.60 -o eth1 -p tcp -m tcp --dport 10003 -j SNAT --to-source 18.187.7.142 
    -A FORWARD -d 18.181.0.60 -i eth1 -o eth1 -p tcp -m tcp --dport 10003 -j ACCEPT

    IP forwarding must also be enabled:
    echo 1 > /proc/sys/net/ipv4/ip_forward

    The auth token passed to the template is urlsafe-base64 of a pickled
    ``{'data': <pickled user/machine/expires dict>, 'digest': <HMAC-SHA1 of
    data>}``.  The expiry is set five minutes from now.
    """
    machine = testMachineId(user, fields.getfirst('machine_id'))

    # NOTE(review): this secret, shared with the Xen-side proxy, is
    # hard-coded in source; consider loading it from protected config.
    TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"

    payload = {}
    payload["user"] = user.username
    payload["machine"] = machine.name
    payload["expires"] = time.time() + (5*60)
    signed_data = cPickle.dumps(payload)
    signature = hmac.new(TOKEN_KEY, signed_data, sha).digest()
    # NOTE(review): the receiving proxy presumably unpickles 'data'; it must
    # verify the digest first, because unpickling untrusted bytes allows
    # arbitrary code execution.
    envelope = cPickle.dumps({'data': signed_data, 'digest': signature})
    authtoken = base64.urlsafe_b64encode(envelope)

    context = dict(user=user,
                   machine=machine,
                   hostname=os.environ.get('SERVER_NAME', 'localhost'),
                   authtoken=authtoken)
    print(Template(file='vnc.tmpl', searchList=[context, global_dict]))

def getNicInfo(data_dict, machine):
    """Helper for info(): add NIC data for ``machine`` to ``data_dict``.

    Stores 'num_nics' and, for NIC number i, 'nic<i>_hostname' (suffixed
    with '.servers.csail.mit.edu'), 'nic<i>_mac' and 'nic<i>_ip'.  Returns
    the matching (key, label) pairs used to display "label: data_dict[key]".
    When the machine has exactly one NIC, the "NIC 0 " label prefix is
    dropped.
    """
    nics = machine.nics
    data_dict['num_nics'] = len(nics)
    templates = [('nic%s_hostname', 'NIC %s hostname'),
                 ('nic%s_mac', 'NIC %s MAC Addr'),
                 ('nic%s_ip', 'NIC %s IP'),
                 ]
    result = []
    for index, nic in enumerate(nics):
        for key_fmt, label_fmt in templates:
            result.append((key_fmt % index, label_fmt % index))
        data_dict['nic%s_hostname' % index] = nic.hostname + '.servers.csail.mit.edu'
        data_dict['nic%s_mac' % index] = nic.mac_addr
        data_dict['nic%s_ip' % index] = nic.ip
    if len(nics) == 1:
        result = [(key, label.replace('NIC 0 ', '')) for key, label in result]
    return result

def getDiskInfo(data_dict, machine):
    """Helper for info(): add disk data for ``machine`` to ``data_dict``.

    Stores 'num_disks' and, for each disk, '<device>_size' formatted in GB
    (sizes are stored in MB).  Returns the matching (key, label) pairs used
    to display "label: data_dict[key]".
    """
    disks = machine.disks
    data_dict['num_disks'] = len(disks)
    result = []
    for disk in disks:
        dev = disk.guest_device_name
        result.append(('%s_size' % dev, '%s size' % dev))
        data_dict['%s_size' % dev] = "%0.1f GB" % (disk.size / 1024.)
    return result

def deleteVM(machine):
    """Delete a VM.

    In one database transaction, this frees the machine's NICs (clearing
    machine_id and hostname) and deletes its disks and the machine row.
    After the commit it removes each disk's LVM volume and unregisters the
    machine.  The transaction is rolled back if any database step fails.
    """
    transaction = ctx.current.create_transaction()
    # Record volume names now; the disk rows are deleted below.
    volumes = [(machine.name, disk.guest_device_name) for disk in machine.disks]
    try:
        for nic in machine.nics:
            nic.machine_id = None
            nic.hostname = None
            ctx.current.save(nic)
        for disk in machine.disks:
            ctx.current.delete(disk)
        ctx.current.delete(machine)
        transaction.commit()
    except:
        transaction.rollback()
        raise
    for machine_name, device in volumes:
        remctl('web', 'lvremove', machine_name, device)
    unregisterMachine(machine)

def command(user, fields):
    """Handler for running commands like boot and delete on a VM.

    Reads 'machine_id', 'action' and the optional 'cdrom' form fields.
    Supported actions:
      'Reboot'    -- reboot, optionally from the given CD;
      'Power on'  -- boot (from the CD, if given) after a RAM quota check;
      'Power off' -- destroy (hard power off);
      'Shutdown'  -- shutdown;
      'Delete VM' -- deleteVM().
    Renders command.tmpl afterwards.  Raises CodeError for an unknown action
    or CD, and InvalidInput when there is not enough free RAM quota.
    """
    # Timing debug prints were removed here: they wrote numbers into the
    # HTML body and relied on a start_time global set only under __main__.
    machine = testMachineId(user, fields.getfirst('machine_id'))
    action = fields.getfirst('action')
    cdrom = fields.getfirst('cdrom')
    if cdrom is not None and not CDROM.get(cdrom):
        raise CodeError("Invalid cdrom type '%s'" % cdrom)
    if action not in ('Reboot', 'Power on', 'Power off', 'Shutdown', 'Delete VM'):
        raise CodeError("Invalid action '%s'" % action)
    if action == 'Reboot':
        if cdrom is not None:
            remctl('reboot', machine.name, cdrom)
        else:
            remctl('reboot', machine.name)
    elif action == 'Power on':
        if maxMemory(user) < machine.memory:
            raise InvalidInput("You don't have enough free RAM quota")
        bootMachine(machine, cdrom)
    elif action == 'Power off':
        remctl('destroy', machine.name)
    elif action == 'Shutdown':
        remctl('shutdown', machine.name)
    elif action == 'Delete VM':
        deleteVM(machine)

    d = dict(user=user,
             command=action,
             machine=machine)
    print(Template(file="command.tmpl", searchList=[d, global_dict]))
        
def modify(user, fields):
    """Handler for modifying attributes of a machine.

    XXX not written yet: it only validates 'machine_id' and the user's
    access, and produces no page output.
    """
    testMachineId(user, fields.getfirst('machine_id'))
    
def help(user, fields):
    """Handler for help pages.

    Renders help.tmpl for the requested 'subject' field(s).  The 'simple'
    field is passed to the template unchanged.  Help texts come from the
    topic -> text table below, passed to the template as ``mapping``.
    """
    topics = {
        'paravm_console': """
ParaVM machines do not support console access over VNC.  To access
these machines, you either need to boot with a liveCD and ssh in or
hope that the sipb-xen maintainers add support for serial consoles.""",
        'hvm_paravm': """
HVM machines use the virtualization features of the processor, while
ParaVM machines use Xen's emulation of virtualization features.  You
want an HVM virtualized machine.""",
        'cpu_weight': """Don't ask us!  We're as mystified as you are.""",
    }
    context = dict(user=user,
                   simple=fields.getfirst('simple'),
                   subjects=fields.getlist('subject'),
                   mapping=topics)
    print(Template(file="help.tmpl", searchList=[context, global_dict]))
    

def info(user, fields):
    """Handler for info on a single VM.

    Looks up the machine named by the 'machine_id' form field, checking
    access.  It combines the live status from statusInfo() with database
    fields (owner, contact, type, NICs, disks) and renders info.tmpl with a
    list of (label, value) rows plus the machine's quota limits.
    """
    machine = testMachineId(user, fields.getfirst('machine_id'))
    status = statusInfo(machine)
    has_vnc = hasVnc(status)
    if status is None:
        # The host says the machine does not exist (it is not running), so
        # fall back to the values stored in the database.
        main_status = dict(name=machine.name,
                           memory=str(machine.memory))
    else:
        # status[0] is the top-level group's leading word (presumably
        # 'domain'); the remaining entries are key/value pairs.
        main_status = dict(status[1:])
    # NOTE(review): with no 'start_time' (machine off) this falls back to 0,
    # so the reported uptime is the time since the Unix epoch.
    start_time = float(main_status.get('start_time', 0))
    uptime = datetime.timedelta(seconds=int(time.time()-start_time))
    cpu_time_float = float(main_status.get('cpu_time', 0))
    cputime = datetime.timedelta(seconds=int(cpu_time_float))
    # Values computed here instead of read from main_status/machine_info.
    computed = {'uptime': uptime, 'cputime': cputime}
    # (key, label) rows in display order.  The 'NIC_INFO' and 'DISK_INFO'
    # markers are replaced with the per-NIC and per-disk rows below.
    display_fields = [('name', 'Name'),
                      ('owner', 'Owner'),
                      ('contact', 'Contact'),
                      ('type', 'Type'),
                      'NIC_INFO',
                      ('uptime', 'uptime'),
                      ('cputime', 'CPU usage'),
                      ('memory', 'RAM'),
                      'DISK_INFO',
                      ('state', 'state (xen format)'),
                      ('cpu_weight', 'CPU weight'+helppopup('cpu_weight')),
                      ('on_reboot', 'Action on VM reboot'),
                      ('on_poweroff', 'Action on VM poweroff'),
                      ('on_crash', 'Action on VM crash'),
                      ('on_xend_start', 'Action on Xen start'),
                      ('on_xend_stop', 'Action on Xen stop'),
                      ('bootloader', 'Bootloader options'),
                      ]
    machine_info = {}
    machine_info['type'] = machine.type.hvm and 'HVM' or 'ParaVM'
    machine_info['owner'] = machine.owner
    machine_info['contact'] = machine.contact

    nic_fields = getNicInfo(machine_info, machine)
    nic_point = display_fields.index('NIC_INFO')
    display_fields[nic_point:nic_point+1] = nic_fields

    disk_fields = getDiskInfo(machine_info, machine)
    disk_point = display_fields.index('DISK_INFO')
    display_fields[disk_point:disk_point+1] = disk_fields

    main_status['memory'] += ' MB'
    rows = []
    for field, disp in display_fields:
        if field in computed:
            rows.append((disp, computed[field]))
        elif field in main_status:
            rows.append((disp, main_status[field]))
        elif field in machine_info:
            rows.append((disp, machine_info[field]))
        # Fields with no known value are left off the page.
    max_mem = maxMemory(user, machine)
    max_disk = maxDisk(user, machine)
    d = dict(user=user,
             cdroms=CDROM.select(),
             on=status is not None,
             machine=machine,
             has_vnc=has_vnc,
             uptime=str(uptime),
             ram=machine.memory,
             max_mem=max_mem,
             max_disk=max_disk,
             fields=rows)
    print(Template(file='info.tmpl',
                   searchList=[d, global_dict]))

# Dispatch table: operation name (from PATH_INFO) -> request handler.
mapping = {
    'list': listVms,
    'vnc': vnc,
    'command': command,
    'modify': modify,
    'info': info,
    'create': create,
    'help': help,
}

if __name__ == '__main__':
    start_time = time.time()
    fields = cgi.FieldStorage()
    class User:
        username = "moo"
        email = 'moo@cow.com'
    u = User()
    if 'SSL_CLIENT_S_DN_Email' in os.environ:
        username = os.environ[ 'SSL_CLIENT_S_DN_Email'].split("@")[0]
        u.username = username
        u.email = os.environ[ 'SSL_CLIENT_S_DN_Email']
    else:
        u.username = 'moo'
        u.email = 'nobody'
    connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')
    operation = os.environ.get('PATH_INFO', '')
    #print 'Content-Type: text/plain\n'
    #print operation
    if not operation:
        print "Status: 301 Moved Permanently"
        print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
        sys.exit(0)
    print 'Content-Type: text/html\n'

    if operation.startswith('/'):
        operation = operation[1:]
    if not operation:
        operation = 'list'
    
    fun = mapping.get(operation, 
                      lambda u, e:
                          error(operation, u, e,
                                "Invalid operation '%s'" % operation))
    if fun not in (help, ):
        connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')
    try:
        fun(u, fields)
    except CodeError, err:
        error(operation, u, fields, err)
    except InvalidInput, err:
        error(operation, u, fields, err)
