#!/usr/bin/python
"""Main CGI script for web interface"""

import base64
import cPickle
import cgi
import datetime
import hmac
import os
import sha
import simplejson
import sys
import time
from StringIO import StringIO

def revertStandardError():
    """Restore stderr to stdout and hand back whatever was buffered.

    If sys.stderr is currently a StringIO buffer (installed when this file
    runs as __main__), point sys.stderr back at sys.stdout and return the
    buffer's full contents. Otherwise leave sys.stderr untouched and
    return None.
    """
    buffered = sys.stderr
    if isinstance(buffered, StringIO):
        sys.stderr = sys.stdout
        buffered.seek(0)
        return buffered.read()
    return None

def printError():
    """Revert stderr to stdout, and print the contents of stderr"""
    if isinstance(sys.stderr, StringIO):
        print revertStandardError()

if __name__ == '__main__':
    # When run as the CGI entry point, buffer everything written to stderr
    # in a StringIO. printError is registered with atexit so that, if the
    # buffer has not already been reverted and shown, its contents are
    # printed to stdout when the interpreter exits.
    import atexit
    atexit.register(printError)
    sys.stderr = StringIO()

# NOTE(review): hard-coded path into one developer's home directory. It is
# presumably where some of the third-party imports below (e.g. Cheetah) are
# installed. TODO confirm, and move this into deployment configuration.
sys.path.append('/home/ecprice/.local/lib/python2.5/site-packages')

import templates
from Cheetah.Template import Template
from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess
import validation
from webcommon import InvalidInput, CodeError, g
import controls

class Checkpoint:
    """Collect labelled wall-clock timestamps for coarse request profiling.

    Every recorded mark is reported as seconds elapsed since the object
    was created.
    """
    def __init__(self):
        self.start_time = time.time()
        self.checkpoints = []

    def checkpoint(self, s):
        """Record label *s* together with the current time."""
        now = time.time()
        self.checkpoints.append((s, now))

    def __str__(self):
        lines = []
        for label, stamp in self.checkpoints:
            lines.append('%s: %s' % (label, stamp - self.start_time))
        return 'Timing info:\n' + '\n'.join(lines) + '\n'

# Module-wide timer shared by all handlers in this request.
checkpoint = Checkpoint()


def helppopup(subj):
    """Return the HTML for a "(?)" link to help topic *subj*.

    The link opens help?subject=<subj>&simple=true in a new window, and its
    onclick calls the page's JavaScript helppopup(<subj>) handler. *subj*
    is inserted verbatim, so it must be a trusted string.
    """
    pieces = ['<span class="helplink"><a href="help?subject=',
              subj,
              '&amp;simple=true" target="_blank" ',
              'onclick="return helppopup(\'',
              subj,
              '\')">(?)</a></span>']
    return ''.join(pieces)

def makeErrorPre(old, addition):
    """Append *addition* to an HTML stderr block and return the new block.

    *old* is either empty/None or a block previously produced by this
    function, which always ends with '</pre>'. A new block starts with a
    "STDERR:" heading. Further additions are inserted before the closing
    '</pre>' and separated by a '----' line.

    If *addition* is None, *old* is returned unchanged. Returning None here
    would discard text that has already been accumulated.

    NOTE(review): *addition* is inserted into HTML unescaped.
    """
    if addition is None:
        return old
    if old:
        # Strip the trailing '</pre>' (6 chars) and re-close after appending.
        return old[:-6] + '\n----\n' + str(addition) + '</pre>'
    else:
        return '<p>STDERR:</p><pre>' + str(addition) + '</pre>'

# Make helppopup callable from every Cheetah template. Also give all
# templates a class-level `err` attribute defaulting to None, presumably so
# templates can reference it even when a handler does not supply one.
# TODO confirm against the templates.
Template.helppopup = staticmethod(helppopup)
Template.err = None

class JsonDict:
    """Accumulate keyword data that is serialised to JSON when printed.

    An ``err`` keyword is not stored as given. It is passed through
    addError, so it ends up wrapped as an HTML stderr block.
    """
    def __init__(self, **kws):
        self.data = kws
        if 'err' in self.data:
            self.addError(self.data.pop('err'))

    def __str__(self):
        return simplejson.dumps(self.data)

    def addError(self, text):
        """Append *text* to the stderr block stored under the 'err' key."""
        previous = self.data.get('err')
        self.data['err'] = makeErrorPre(previous, text)

class Defaults:
    """Default values used to pre-fill the create/modify forms.

    The class attributes hold the site-wide defaults. The constructor can
    clamp the default memory and disk size to caller-supplied maxima. Any
    extra keyword arguments become instance attributes, applied after the
    clamping, so they can override any default.
    """
    memory = 256
    disk = 4.0
    cdrom = ''
    name = ''
    vmtype = 'hvm'
    def __init__(self, max_memory=None, max_disk=None, **kws):
        if max_memory is not None:
            self.memory = min(self.memory, max_memory)
        if max_disk is not None:
            # Clamp the default disk size, mirroring memory above. Previously
            # this wrote an unused `max_disk` attribute and left `disk`
            # unclamped.
            self.disk = min(self.disk, max_disk)
        for key in kws:
            setattr(self, key, kws[key])



# Headers sent with every successful response. main() copies this dict and
# merges in any extra headers a handler returns as (headers, output).
DEFAULT_HEADERS = {'Content-Type': 'text/html'}

def error(op, user, fields, err, emsg):
    """Render the generic error page shown when a CodeError escapes a handler.

    *fields* is accepted for signature symmetry with invalidInput and is
    not used. *emsg* is the captured stderr text, or None.
    """
    context = {'op': op,
               'user': user,
               'errorMessage': str(err),
               'stderr': emsg}
    return templates.error(searchList=[context])

def invalidInput(op, user, fields, err, emsg):
    """Render the invalid-input page for an InvalidInput exception *err*.

    The offending field name and value are taken from err.err_field and
    err.err_value. *fields* is unused. *emsg* is the captured stderr text,
    or None.
    """
    context = {'op': op,
               'user': user,
               'err_field': err.err_field,
               'err_value': str(err.err_value),
               'stderr': emsg,
               'errorMessage': str(err)}
    return templates.invalid(searchList=[context])

def hasVnc(status):
    """Report whether a machine's status list describes a VNC-capable framebuffer.

    Looks for the first ('device', ('vfb', (key, value), ...)) entry and
    returns True if that device's options include a 'location' key. A None
    status (machine off) or no vfb device gives False.
    """
    if status is None:
        return False
    for entry in status:
        if not (entry[0] == 'device' and entry[1][0] == 'vfb'):
            continue
        options = dict(entry[1][1:])
        return 'location' in options
    return False

def parseCreate(user, fields):
    """Validate the fields of a create request.

    Returns a dict of keyword arguments for controls.createVm with the keys
    contact, name, memory, disk_size, owner, is_hvm, cdrom and clone_from.

    Raises InvalidInput for a missing/invalid or already-used machine name.
    Raises CodeError for vm type, cdrom or clone image values the form
    should never submit. The validation.* helpers may raise their own
    errors (not visible here).
    """
    name = fields.getfirst('name')
    if not validation.validMachineName(name):
        raise InvalidInput('name', name, 'You must provide a machine name.')
    name = name.lower()
    if Machine.get_by(name=name):
        raise InvalidInput('name', name,
                           "Name already exists.")

    owner = validation.testOwner(user, fields.getfirst('owner'))
    memory = validation.validMemory(owner, fields.getfirst('memory'), on=True)
    disk_size = validation.validDisk(owner, fields.getfirst('disk'))

    vm_type = fields.getfirst('vmtype')
    if vm_type not in ('hvm', 'paravm'):
        raise CodeError("Invalid vm type '%s'"  % vm_type)

    cdrom = fields.getfirst('cdrom')
    if cdrom is not None and not CDROM.get(cdrom):
        raise CodeError("Invalid cdrom type '%s'" % cdrom)

    # 'ice3' is the only clone image accepted here.
    clone_from = fields.getfirst('clone_from')
    if clone_from and clone_from != 'ice3':
        raise CodeError("Invalid clone image '%s'" % clone_from)

    return {'contact': user,
            'name': name,
            'memory': memory,
            'disk_size': disk_size,
            'owner': owner,
            'is_hvm': vm_type == 'hvm',
            'cdrom': cdrom,
            'clone_from': clone_from}

def create(user, fields):
    """Handler for create requests."""
    try:
        parsed_fields = parseCreate(user, fields)
        machine = controls.createVm(**parsed_fields)
    except InvalidInput, err:
        pass
    else:
        err = None
    g.clear() #Changed global state
    d = getListDict(user)
    d['err'] = err
    if err:
        for field in fields.keys():
            setattr(d['defaults'], field, fields.getfirst(field))
    else:
        d['new_machine'] = parsed_fields['name']
    return templates.list(searchList=[d])


def getListDict(user):
    """Build the template search dictionary for the machine list page.

    Sets `uptime` on each machine in g.machines. It also builds has_vnc,
    which maps each machine to 'Off' when it has no uptime, True for HVM
    machines, or a "ParaVM" label plus a help link otherwise. Both are
    bundled with the user's limits and form defaults.
    """
    machines = g.machines
    checkpoint.checkpoint('Got my machines')
    uptimes = g.uptimes
    checkpoint.checkpoint('Got uptimes')
    has_vnc = {}
    for m in machines:
        # Use .get so a machine missing from the uptime table is shown as
        # off instead of raising KeyError.
        uptime = uptimes.get(m)
        m.uptime = uptime
        if not uptime:
            has_vnc[m] = 'Off'
        elif m.type.hvm:
            has_vnc[m] = True
        else:
            has_vnc[m] = "ParaVM"+helppopup("paravm_console")
    max_memory = validation.maxMemory(user)
    max_disk = validation.maxDisk(user)
    checkpoint.checkpoint('Got max mem/disk')
    defaults = Defaults(max_memory=max_memory,
                        max_disk=max_disk,
                        owner=user,
                        cdrom='gutsy-i386')
    checkpoint.checkpoint('Got defaults')
    d = dict(user=user,
             cant_add_vm=validation.cantAddVm(user),
             max_memory=max_memory,
             max_disk=max_disk,
             defaults=defaults,
             machines=machines,
             has_vnc=has_vnc,
             uptimes=uptimes,
             cdroms=CDROM.select())
    return d

def listVms(user, fields):
    """Handler for list requests: render the machine list page for *user*.

    *fields* is accepted to match the common handler signature and is
    unused.
    """
    checkpoint.checkpoint('Getting list dict')
    list_dict = getListDict(user)
    checkpoint.checkpoint('Got list dict')
    page = templates.list(searchList=[list_dict])
    return page
            
def vnc(user, fields):
    """VNC applet page.

    Note that due to same-domain restrictions, the applet connects to
    the webserver, which needs to forward those requests to the xen
    server.  The Xen server runs another proxy that (1) authenticates
    and (2) finds the correct port for the VM.

    The page embeds an auth token: a urlsafe-base64 pickle of
    {'data': <pickled {user, machine, expires}>, 'digest': <HMAC of data>},
    valid for 5 minutes from page generation.

    You might want iptables like:

    -t nat -A PREROUTING -s ! 18.181.0.60 -i eth1 -p tcp -m tcp \
      --dport 10003 -j DNAT --to-destination 18.181.0.60:10003 
    -t nat -A POSTROUTING -d 18.181.0.60 -o eth1 -p tcp -m tcp \
      --dport 10003 -j SNAT --to-source 18.187.7.142 
    -A FORWARD -d 18.181.0.60 -i eth1 -o eth1 -p tcp -m tcp \
      --dport 10003 -j ACCEPT

    Remember to enable IP forwarding as well:
    echo 1 > /proc/sys/net/ipv4/ip_forward
    """
    machine = validation.testMachineId(user, fields.getfirst('machine_id'))

    # NOTE(review): the HMAC key is hard-coded in source, so anyone who can
    # read this file can forge tokens. It should presumably be loaded from
    # protected configuration shared with the Xen-side proxy.
    TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"

    # Token payload: who is connecting, to which machine, and an absolute
    # expiry time 5 minutes from now.
    data = {}
    data["user"] = user
    data["machine"] = machine.name
    data["expires"] = time.time()+(5*60)
    pickled_data = cPickle.dumps(data)
    # Sign the pickled payload with HMAC; the `sha` module (SHA-1) is the digest.
    m = hmac.new(TOKEN_KEY, digestmod=sha)
    m.update(pickled_data)
    # NOTE(review): the token itself is a pickle. Whatever verifies it (the
    # Xen-side proxy, not visible here) must unpickle this outer dict to
    # reach the digest, which is unsafe on attacker-supplied input. A
    # non-pickle encoding for the outer envelope is worth considering.
    token = {'data': pickled_data, 'digest': m.digest()}
    token = cPickle.dumps(token)
    token = base64.urlsafe_b64encode(token)

    status = controls.statusInfo(machine)
    has_vnc = hasVnc(status)

    d = dict(user=user,
             on=status,
             has_vnc=has_vnc,
             machine=machine,
             hostname=os.environ.get('SERVER_NAME', 'localhost'),
             authtoken=token)
    return templates.vnc(searchList=[d])

def getHostname(nic):
    """Return a display hostname for *nic*.

    A hostname that already contains a dot is returned as-is. Otherwise,
    if the NIC is attached to a machine, the name is built as
    '<machine name>.servers.csail.mit.edu'. With neither, returns None.
    """
    hostname = nic.hostname
    if hostname and '.' in hostname:
        return hostname
    if not nic.machine:
        return None
    return nic.machine.name + '.servers.csail.mit.edu'


def getNicInfo(data_dict, machine):
    """Helper function for info, get data on nics for a machine.

    Adds 'num_nics' plus per-NIC 'nic<i>_mac' and 'nic<i>_ip' values to
    *data_dict*, and 'nic0_hostname' for the first NIC only. Returns a list
    of (key, label) pairs so the caller can display "label: data_dict[key]".
    That list includes a hostname row for every NIC. When the machine has
    exactly one NIC, the "NIC 0 " prefix is dropped from the labels.
    """
    nics = machine.nics
    data_dict['num_nics'] = len(nics)
    field_patterns = [('nic%s_hostname', 'NIC %s Hostname'),
                      ('nic%s_mac', 'NIC %s MAC Addr'),
                      ('nic%s_ip', 'NIC %s IP')]
    nic_fields = []
    for index, nic in enumerate(nics):
        for key_pattern, label_pattern in field_patterns:
            nic_fields.append((key_pattern % index, label_pattern % index))
        # Only the first NIC gets a hostname value.
        if index == 0:
            data_dict['nic%s_hostname' % index] = getHostname(nic)
        data_dict['nic%s_mac' % index] = nic.mac_addr
        data_dict['nic%s_ip' % index] = nic.ip
    if len(nics) == 1:
        nic_fields = [(key, label.replace('NIC 0 ', ''))
                      for key, label in nic_fields]
    return nic_fields

def getDiskInfo(data_dict, machine):
    """Helper function for info, get data on disks for a machine.

    Adds 'num_disks' and, for each disk, '<device>_size' to *data_dict*.
    The size is formatted as GiB with one decimal, computed as
    disk.size / 1024. Returns a list of (key, label) pairs so the caller
    can display "label: data_dict[key]".
    """
    disks = machine.disks
    data_dict['num_disks'] = len(disks)
    disk_fields = []
    for disk in disks:
        device = disk.guest_device_name
        size_key = '%s_size' % device
        disk_fields.append((size_key, '%s size' % device))
        data_dict[size_key] = "%0.1f GiB" % (disk.size / 1024.)
    return disk_fields

def command(user, fields):
    """Handler for running commands like boot and delete on a VM.

    The command itself is run by controls.commandResult. The optional
    'back' field selects which page is rendered afterwards:

    - no 'back': on success render the command template; on InvalidInput
      re-raise it.
    - 'list': re-render the machine list. This is forced after
      'Delete VM', presumably because the deleted machine has no info page.
    - 'info': re-render the machine's info page.
    - anything else: raise InvalidInput for the 'back' field.

    For 'list' and 'info', an InvalidInput from the command is not raised
    but shown on the page as `result`. On success `result` is 'Success!'.
    """
    back = fields.getfirst('back')
    try:
        d = controls.commandResult(user, fields)
        if d['command'] == 'Delete VM':
            back = 'list'
    except InvalidInput, err:
        if not back:
            raise
        #print >> sys.stderr, err
        result = err
    else:
        result = 'Success!'
        if not back:
            return templates.command(searchList=[d])
    if back == 'list':
        g.clear() #Changed global state
        d = getListDict(user)
        d['result'] = result
        return templates.list(searchList=[d])
    elif back == 'info':
        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
        d = infoDict(user, machine)
        d['result'] = result
        return templates.info(searchList=[d])
    else:
        raise InvalidInput('back', back, 'Not a known back page.')

def modifyDict(user, fields):
    """Apply a modify request to a machine and return template data.

    Owner, administrator, contact, name, memory and the first disk's size
    are validated and saved inside a single database transaction. On any
    exception the transaction is rolled back and the exception re-raised.
    Only after a successful commit are the external actions performed:
    resizing changed disks and renaming the machine via controls.

    Returns a dict with 'user', 'command' ("modify") and 'machine'.
    """
    # Despite its name, maps guest_device_name -> the NEW size for each
    # disk whose size changed; used after commit to resize the disks.
    olddisk = {}
    transaction = ctx.current.create_transaction()
    try:
        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
        owner = validation.testOwner(user, fields.getfirst('owner'), machine)
        admin = validation.testAdmin(user, fields.getfirst('administrator'),
                                     machine)
        contact = validation.testContact(user, fields.getfirst('contact'),
                                         machine)
        name = validation.testName(user, fields.getfirst('name'), machine)
        oldname = machine.name
        command = "modify"

        memory = fields.getfirst('memory')
        if memory is not None:
            memory = validation.validMemory(user, memory, machine, on=False)
            machine.memory = memory

        # Only the machine's first disk can be resized from this form.
        disksize = validation.testDisk(user, fields.getfirst('disk'))
        if disksize is not None:
            disksize = validation.validDisk(user, disksize, machine)
            disk = machine.disks[0]
            if disk.size != disksize:
                olddisk[disk.guest_device_name] = disksize
                disk.size = disksize
                ctx.current.save(disk)

        # The validation.test* helpers return None for fields to leave unchanged.
        if owner is not None:
            machine.owner = owner
        if name is not None:
            machine.name = name
        if admin is not None:
            machine.administrator = admin
        if contact is not None:
            machine.contact = contact

        ctx.current.save(machine)
        transaction.commit()
    except:
        # Bare except is deliberate: roll back on anything, then re-raise.
        transaction.rollback()
        raise
    # The rename below has not happened yet, so disks are addressed by oldname.
    for diskname in olddisk:
        controls.resizeDisk(oldname, diskname, str(olddisk[diskname]))
    if name is not None:
        controls.renameMachine(machine, oldname, name)
    return dict(user=user,
                command=command,
                machine=machine)
    
def modify(user, fields):
    """Handler for modifying attributes of a machine."""
    try:
        modify_dict = modifyDict(user, fields)
    except InvalidInput, err:
        result = None
        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
    else:
        machine = modify_dict['machine']
        result = 'Success!'
        err = None
    info_dict = infoDict(user, machine)
    info_dict['err'] = err
    if err:
        for field in fields.keys():
            setattr(info_dict['defaults'], field, fields.getfirst(field))
    info_dict['result'] = result
    return templates.info(searchList=[info_dict])
    

def helpHandler(user, fields):
    """Handler for help messages.

    Renders the help template for the topics named by the 'subject' field,
    which may be repeated. With no subject, every topic is shown in sorted
    order. The 'simple' field is passed through to the template unchanged.
    """
    simple = fields.getfirst('simple')
    subjects = fields.getlist('subject')

    # Topic name -> HTML help text.
    help_mapping = dict(paravm_console="""
ParaVM machines do not support console access over VNC.  To access
these machines, you either need to boot with a liveCD and ssh in or
hope that the sipb-xen maintainers add support for serial consoles.""",
                        hvm_paravm="""
HVM machines use the virtualization features of the processor, while
ParaVM machines use Xen's emulation of virtualization features.  You
want an HVM virtualized machine.""",
                        cpu_weight="""
Don't ask us!  We're as mystified as you are.""",
                        owner="""
The owner field is used to determine <a
href="help?subject=quotas">quotas</a>.  It must be the name of a
locker that you are an AFS administrator of.  In particular, you or an
AFS group you are a member of must have AFS rlidwka bits on the
locker.  (You can see who administers the LOCKER locker using the
command 'fs la /mit/LOCKER' on Athena.)  See also <a
href="help?subject=administrator">administrator</a>.""",
                        administrator="""
The administrator field determines who can access the console and
power on and off the machine.  This can be either a user or a moira
group.""",
                        quotas="""
Quotas are determined on a per-locker basis.  Each quota may have a
maximum of 512 megabytes of active ram, 50 gigabytes of disk, and 4
active machines.""",
                        console="""
<strong>Framebuffer:</strong> At a Linux boot prompt in your VM, try
setting <tt>fb=false</tt> to disable the framebuffer.  If you don't,
your machine will run just fine, but the applet's display of the
console will suffer artifacts.
"""
                   )

    if not subjects:
        subjects = sorted(help_mapping.keys())

    d = dict(user=user,
             simple=simple,
             subjects=subjects,
             mapping=help_mapping)

    return templates.help(searchList=[d])
    

def badOperation(u, e):
    """Fallback handler for unrecognised operations; always raises CodeError.

    Takes the same two positional (user, fields) arguments as the real
    handlers and ignores them.
    """
    message = "Unknown operation"
    raise CodeError(message)

def infoDict(user, machine):
    """Build the template search dictionary for a machine's info page.

    Combines database attributes of *machine* with its live status from
    controls.statusInfo, where None means the machine is off. The result is
    an ordered list of (label, value) rows in d['fields'], plus the form
    defaults and the memory/disk limits for the modify form.
    """
    status = controls.statusInfo(machine)
    checkpoint.checkpoint('Getting status info')
    has_vnc = hasVnc(status)
    if status is None:
        # Machine is off: only name and memory come from the database.
        main_status = dict(name=machine.name,
                           memory=str(machine.memory))
        uptime = None
        cputime = None
    else:
        # status[0] is skipped; the remaining entries are (key, value) pairs.
        main_status = dict(status[1:])
        start_time = float(main_status.get('start_time', 0))
        uptime = datetime.timedelta(seconds=int(time.time()-start_time))
        cpu_time_float = float(main_status.get('cpu_time', 0))
        cputime = datetime.timedelta(seconds=int(cpu_time_float))
    checkpoint.checkpoint('Status')
    # Derived timing rows; shown only when known, i.e. the machine is running.
    timings = {'uptime': uptime, 'cputime': cputime}
    # Display order of the info rows. 'NIC_INFO' and 'DISK_INFO' are
    # placeholders replaced below by the per-NIC / per-disk rows.
    display_fields = [('name', 'Name'),
                      ('owner', 'Owner'),
                      ('administrator', 'Administrator'),
                      ('contact', 'Contact'),
                      ('type', 'Type'),
                      'NIC_INFO',
                      ('uptime', 'uptime'),
                      ('cputime', 'CPU usage'),
                      ('memory', 'RAM'),
                      'DISK_INFO',
                      ('state', 'state (xen format)'),
                      ('cpu_weight', 'CPU weight'+helppopup('cpu_weight')),
                      ('on_reboot', 'Action on VM reboot'),
                      ('on_poweroff', 'Action on VM poweroff'),
                      ('on_crash', 'Action on VM crash'),
                      ('on_xend_start', 'Action on Xen start'),
                      ('on_xend_stop', 'Action on Xen stop'),
                      ('bootloader', 'Bootloader options'),
                      ]
    machine_info = {}
    machine_info['name'] = machine.name
    machine_info['type'] = 'HVM' if machine.type.hvm else 'ParaVM'
    machine_info['owner'] = machine.owner
    machine_info['administrator'] = machine.administrator
    machine_info['contact'] = machine.contact

    nic_fields = getNicInfo(machine_info, machine)
    nic_point = display_fields.index('NIC_INFO')
    display_fields = (display_fields[:nic_point] + nic_fields +
                      display_fields[nic_point+1:])

    disk_fields = getDiskInfo(machine_info, machine)
    disk_point = display_fields.index('DISK_INFO')
    display_fields = (display_fields[:disk_point] + disk_fields +
                      display_fields[disk_point+1:])

    main_status['memory'] += ' MiB'
    # Each value comes from the first source that has it: timings, then the
    # database-derived machine_info, then the live status. Fields with no
    # value anywhere are omitted.
    fields = []
    for field, disp in display_fields:
        if timings.get(field) is not None:
            fields.append((disp, timings[field]))
        elif field in machine_info:
            fields.append((disp, machine_info[field]))
        elif field in main_status:
            fields.append((disp, main_status[field]))

    checkpoint.checkpoint('Got fields')


    max_mem = validation.maxMemory(user, machine, False)
    checkpoint.checkpoint('Got mem')
    max_disk = validation.maxDisk(user, machine)
    defaults = Defaults()
    for name in 'machine_id name administrator owner memory contact'.split():
        setattr(defaults, name, getattr(machine, name))
    # NOTE(review): assumes the machine has at least one disk (IndexError otherwise).
    defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
    checkpoint.checkpoint('Got defaults')
    d = dict(user=user,
             cdroms=CDROM.select(),
             on=status is not None,
             machine=machine,
             defaults=defaults,
             has_vnc=has_vnc,
             uptime=str(uptime),
             ram=machine.memory,
             max_mem=max_mem,
             max_disk=max_disk,
             owner_help=helppopup("owner"),
             fields = fields)
    return d

def info(user, fields):
    """Handler for info on a single VM, selected by the 'machine_id' field."""
    machine_id = fields.getfirst('machine_id')
    machine = validation.testMachineId(user, machine_id)
    info_dict = infoDict(user, machine)
    checkpoint.checkpoint('Got infodict')
    return templates.info(searchList=[info_dict])

# Dispatch table: operation name (from PATH_INFO) -> handler(user, fields).
mapping = {
    'list': listVms,
    'vnc': vnc,
    'command': command,
    'modify': modify,
    'info': info,
    'create': create,
    'help': helpHandler,
}

def printHeaders(headers):
    for key, value in headers.iteritems():
        print '%s: %s' % (key, value)
    print


def getUser():
    """Return the current user based on the SSL environment variables.

    This is the local part (text before the first '@') of the client
    certificate's e-mail address. Raises KeyError if
    SSL_CLIENT_S_DN_Email is not set.
    """
    email = os.environ['SSL_CLIENT_S_DN_Email']
    local_part = email.partition("@")[0]
    return local_part

def main(operation, user, fields):    
    start_time = time.time()
    fun = mapping.get(operation, badOperation)

    if fun not in (helpHandler, ):
        connect('postgres://sipb-xen@sipb-xen-dev.mit.edu/sipb_xen')
    try:
        checkpoint.checkpoint('Before')
        output = fun(u, fields)
        checkpoint.checkpoint('After')

        headers = dict(DEFAULT_HEADERS)
        if isinstance(output, tuple):
            new_headers, output = output
            headers.update(new_headers)
        e = revertStandardError()
        if e:
            output.addError(e)
        printHeaders(headers)
        output_string =  str(output)
        checkpoint.checkpoint('output as a string')
        print output_string
        print '<pre>%s</pre>' % checkpoint
    except Exception, err:
        if not fields.has_key('js'):
            if isinstance(err, CodeError):
                print 'Content-Type: text/html\n'
                e = revertStandardError()
                print error(operation, u, fields, err, e)
                sys.exit(1)
            if isinstance(err, InvalidInput):
                print 'Content-Type: text/html\n'
                e = revertStandardError()
                print invalidInput(operation, u, fields, err, e)
                sys.exit(1)
        print 'Content-Type: text/plain\n'
        print 'Uh-oh!  We experienced an error.'
        print 'Please email sipb-xen@mit.edu with the contents of this page.'
        print '----'
        e = revertStandardError()
        print e
        print '----'
        raise

if __name__ == '__main__':
    # CGI entry point: parse the form, identify the user from the SSL client
    # certificate, and dispatch on the path after the script name.
    fields = cgi.FieldStorage()
    u = getUser()
    g.user = u
    operation = os.environ.get('PATH_INFO', '')
    if not operation:
        # No trailing path at all: redirect to "<script>/", presumably so
        # relative links in the generated pages resolve under the script.
        print "Status: 301 Moved Permanently"
        print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
        sys.exit(0)

    # Strip the single leading '/' from PATH_INFO; "<script>/" maps to 'list'.
    if operation.startswith('/'):
        operation = operation[1:]
    if not operation:
        operation = 'list'

    # Setting SIPB_XEN_PROFILE runs the request under the `profile` module,
    # writing stats to the file log-<operation>.
    if os.getenv("SIPB_XEN_PROFILE"):
        import profile
        profile.run('main(operation, u, fields)', 'log-'+operation)
    else:
        main(operation, u, fields)
