#!/usr/bin/env python3

"""FelixSupervisor

FelixSupervisor starts all FELIX applications in sequence and monitors their
execution. It is registered as a Supervisor event listener.
"""

import argparse
import logging
import os
import socket
import sys

from felix_starter.felix_starter import FelixStarter


class FelixSupervisor:
    """Start the FELIX processes stage by stage and monitor their state.

    The instance drives supervisord through a ``FelixStarter`` RPC proxy bound
    to ``localhost:9001`` and afterwards exchanges ``READY`` / ``RESULT``
    messages with its parent over stdin/stdout (see :meth:`monitor`).

    Args:
        sv_config: Path of the Supervisor configuration file the RPC username
            and password are read from. An empty string means the RPC
            interface is contacted without credentials.
        enable_ers: When true, log through ``ers.LoggingHandler`` and report
            ERS issue objects; otherwise log plain text to stderr.
        ers_config: Mapping of environment variables exported before the ERS
            handler is created. It must contain ``TDAQ_PARTITION`` when
            ``enable_ers`` is true.
    """

    def __init__(self, sv_config='', enable_ers=False, ers_config=None):
        # HOSTNAME may be absent from the environment; fall back to the
        # system hostname so the name is never None (it is concatenated into
        # TDAQ_APPLICATION_NAME when ERS logging is enabled).
        self.hostname = os.getenv('HOSTNAME') or socket.gethostname()
        self.run_monitor = True      # monitor() loops while this stays true
        self.ipc_partition = None    # set by setup_logger() in ERS mode
        self.stages = {}             # stage number -> {'procs': [...], 'block': bool}
        self.logger = None
        self.ers_logging = enable_ers
        self.ers_conf = ers_config
        self.setup_logger()
        release = os.path.dirname(os.path.abspath(__file__))
        self.logger.info("Felix-supervisor of release %s started on %s", release, self.hostname)
        if sv_config == '':
            address = "http://localhost:9001/RPC2"
        else:
            auth = self.get_auth(sv_config)
            address = "http://{0}:{1}@localhost:9001/RPC2".format(auth["usr"], auth["pwd"])
        self.rpc = FelixStarter(server_proxy=address, ers_logging=enable_ers)

    def get_auth(self, sv_config):
        """Read the RPC credentials from a Supervisor configuration file.

        The file is scanned line by line for ``key = value`` entries whose key
        contains ``username`` or ``password``; password lines mentioning
        ``SHA1`` are ignored. Scanning stops as soon as both are known.

        Args:
            sv_config: Path of the configuration file.

        Returns:
            dict: ``{"usr": <username>, "pwd": <password>}``.

        Raises:
            SystemExit: With code 1 (after logging a fatal message) when the
                username or the password cannot be found.
        """
        auth = {}
        with open(sv_config, "r") as f:
            for line in f:
                entry = line.strip()
                # Blank lines and full-line comments carry no credentials.
                if not entry or entry[0] in ';#':
                    continue
                # partition() keeps any '=' inside the value and never raises
                # on a line that has no '=' at all.
                key, sep, value = entry.partition('=')
                if not sep:
                    continue
                key = key.strip()
                # strip() without arguments also removes the trailing newline,
                # which would otherwise end up inside the RPC URL.
                value = value.strip()
                if "username" in key:
                    auth["usr"] = value
                if "password" in key and "SHA1" not in line:
                    auth["pwd"] = value
                if len(auth) > 1:
                    break
        if "usr" not in auth or "pwd" not in auth:
            if self.ers_logging:
                self.logger.fatal(self.AuthError())
            else:
                self.logger.fatal("Cannot retrieve RPC interface username and password from config file")
            # sys.exit instead of the interactive-only ``exit`` helper.
            sys.exit(1)
        return auth

    def setup_logger(self):
        """Create ``self.logger`` and attach either the ERS or a stderr handler.

        In ERS mode the ERS/IPC modules and the FELIX issue classes are
        imported lazily, the issue classes are stored on the instance, and
        every entry of ``self.ers_conf`` is exported to ``os.environ`` before
        the ERS handler is instantiated.
        """
        self.logger = logging.getLogger('root')
        self.logger.setLevel(logging.INFO)

        if self.ers_logging:
            # Imported here so that the plain-logging mode works without the
            # TDAQ Python bindings being installed.
            import ers
            from ipc import IPCPartition
            from felix_starter.felix_issues import AuthError, TaskExited, TaskFatal
            self.AuthError = AuthError
            self.TaskExited = TaskExited
            self.TaskFatal = TaskFatal
            self.ipc_partition = IPCPartition(self.ers_conf['TDAQ_PARTITION'])
            if not self.ipc_partition.isValid():
                self.logger.warning("Paritition %s is invalid", self.ers_conf['TDAQ_PARTITION'])
            os.environ["TDAQ_APPLICATION_NAME"] = 'felix-sv-' + self.hostname
            for key, value in self.ers_conf.items():
                os.environ[key] = value

            self.ers_handler = ers.LoggingHandler()
            self.logger.addHandler(self.ers_handler)
            self.logger.debug("Felix-supervisor of %s reporting to %s paritition.", self.hostname, self.ers_conf['TDAQ_PARTITION'])

        else:
            log_format = '%(filename)s %(asctime)s %(levelname)s %(message)s'
            sh = logging.StreamHandler(stream=sys.stderr)
            sh.setLevel(logging.DEBUG)
            sh.setFormatter(logging.Formatter(log_format))
            self.logger.addHandler(sh)

    def determine_stages(self):
        """Group the processes known to the RPC server into numbered stages.

        The stage number is the integer before the first ``_`` of the group
        name; groups without such a prefix are skipped. Each stage stores the
        ``group:name`` identifiers of its processes under ``'procs'`` and a
        ``'block'`` flag telling whether the group name contains ``blocking``.
        """
        info = self.rpc.sv.getAllProcessInfo()
        for entry in info:
            group = entry['group']
            try:
                stage_no = int(group.split('_')[0])
            except ValueError:
                self.logger.debug("Found invalid group name %s.", group)
                continue
            stage = self.stages.setdefault(stage_no, {'procs': []})
            stage['procs'].append(group + ':' + entry['name'])
            # The flag reflects the most recently seen group of the stage.
            stage['block'] = 'blocking' in group

    def synch_start_list(self, tasks):
        """Start ``tasks`` via ``rpc.synch_start_tasks`` and return its result."""
        return self.rpc.synch_start_tasks(tasks)

    def asynch_start_list(self, tasks):
        """Start ``tasks`` via ``rpc.start_tasks`` (called with ``wait=True``)."""
        self.rpc.start_tasks(tasks, wait=True)

    def asynch_start_group(self, group):
        """Start the process group ``group`` via ``rpc.start_group``."""
        self.rpc.start_group(group)

    def reply(self, msg):
        """Write ``msg`` to stdout and flush immediately."""
        sys.stdout.write(msg)
        sys.stdout.flush()

    @staticmethod
    def _parse_tokens(text):
        """Turn whitespace-separated ``key:value`` tokens into a dict.

        Only the first ``:`` separates key and value, so values may contain
        colons themselves. Tokens without any ``:`` are ignored.
        """
        fields = {}
        for token in text.split():
            key, sep, value = token.partition(':')
            if sep:
                fields[key] = value
        return fields

    def monitor(self):
        """Process events from stdin until stopped or stdin is closed.

        Each iteration announces ``READY``, reads one header line, reads the
        ``len`` characters of payload announced by that header, merges both
        into a single dict, hands it to :meth:`process_event` and
        acknowledges with ``RESULT 2\\nOK``.
        """
        while self.run_monitor:

            # transition from ACKNOWLEDGED to READY
            self.reply('READY\n')
            line = sys.stdin.readline()
            if not line:
                # EOF: nothing will ever arrive again, leave the loop instead
                # of failing on the missing 'len' field.
                self.logger.info('Event stream closed, stopping monitor.')
                break
            message = self._parse_tokens(line)

            raw_payload = sys.stdin.read(int(message['len']))
            message.update(self._parse_tokens(raw_payload))
            self.logger.debug('Message: %s', message)
            self.process_event(message)

            # transition from READY to ACKNOWLEDGED
            self.reply('RESULT 2\nOK')

    def process_event(self, data):
        """Log a message for the process state change described by ``data``.

        Args:
            data: Dict of event fields. ``eventname`` selects the action;
                events other than the four handled ``PROCESS_STATE_*`` names
                are ignored, including events that carry no ``processname``.
        """
        eventname = data.get('eventname')
        task = data.get('processname')

        if eventname == 'PROCESS_STATE_EXITED':
            if data['expected'] == '0':
                if self.ers_logging:
                    self.logger.error(self.TaskExited(task))
                else:
                    self.logger.error('Task %s exited unexpectedly.', task)
            else:
                self.logger.info('Task %s exited with no errors.', task)

        elif eventname == 'PROCESS_STATE_STARTING':
            self.logger.info('Starting %s', task)

        elif eventname == 'PROCESS_STATE_BACKOFF':
            self.logger.info('Start of process %s failed. Tries %d.', task, int(data['tries']))

        elif eventname == 'PROCESS_STATE_FATAL':
            if self.ers_logging:
                self.logger.error(self.TaskFatal(task))
            else:
                self.logger.error('Start of process %s failed too many times. It will not be restarted.', task)

if __name__ == '__main__':
    # Command line: the only option is the Supervisor config file that holds
    # the RPC credentials (empty string = no credentials).
    cli = argparse.ArgumentParser()
    cli.add_argument('-c', dest='config', type=str, default='', help='Supervisor config file')
    options = cli.parse_args()

    # Environment handed to the supervisor for ERS logging: the partition
    # name plus one stream specification shared by every severity.
    ers_config = {'TDAQ_PARTITION': "initial"}
    ers_config.update(dict.fromkeys(
        (
            'TDAQ_ERS_ERROR',
            'TDAQ_ERS_FATAL',
            'TDAQ_ERS_WARNING',
            'TDAQ_ERS_INFO',
            'TDAQ_ERS_LOG',
            'TDAQ_ERS_DEBUG',
        ),
        "lstderr,mts",
    ))

    sv = FelixSupervisor(sv_config=options.config, enable_ers=True, ers_config=ers_config)
    sv.determine_stages()

    # Walk the stages in ascending order. A blocking stage is started
    # synchronously and a non-zero result stops the start-up sequence;
    # every other stage is handed to the asynchronous starter.
    for stage_no in sorted(sv.stages):
        current = sv.stages[stage_no]
        if current['block'] is not True:
            sv.asynch_start_list(current['procs'])
            continue
        if sv.synch_start_list(current['procs']) != 0:
            break

    # Monitoring begins whether or not the start-up sequence completed.
    sv.monitor()
