#!/usr/bin/env python3

# Copyright 2015 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VSPERF main script.
"""

import logging
import os
import sys
import argparse
import time
import datetime
import shutil
import unittest
import xmlrunner

sys.dont_write_bytecode = True

from conf import settings
from core.loader import Loader
from testcases import TestCase
from tools import tasks
from tools.collectors import collector
from tools.pkt_gen import trafficgen

# Names accepted by the '--verbosity' option, mapped to the numeric level of
# the same (upper-cased) name in the standard logging module. Listed from the
# most to the least verbose.
VERBOSITY_LEVELS = {
    level_name: getattr(logging, level_name.upper())
    for level_name in ('debug', 'info', 'warning', 'error', 'critical')
}


def parse_arguments():
    """
    Parse command line arguments.

    Invalid values for '--test-params', '--conf-file' and '--xunit-dir' are
    reported through argparse, i.e. a usage message is printed and the
    process exits with status 2.

    :return: Dictionary mapping each option's destination name to its parsed
        value (``vars()`` of the argparse namespace).
    """
    class _SplitTestParamsAction(argparse.Action):
        """
        Parse and split the '--test-params' argument.

        This expects a ';'-separated list of either 'x=y' or 'x' (implicit
        true) values. Empty segments (e.g. a trailing ';') are ignored.
        """
        def __call__(self, parser, namespace, values, option_string=None):
            results = {}

            for value in values.split(';'):
                if not value.strip():
                    # nothing between two separators - skip rather than
                    # storing a parameter with an empty name
                    continue

                result = [key.strip() for key in value.split('=')]
                if len(result) == 1:
                    results[result[0]] = True
                elif len(result) == 2:
                    results[result[0]] = result[1]
                else:
                    # ArgumentError is what parse_args() turns into a usage
                    # error; ArgumentTypeError is only handled for 'type='
                    # converters and would escape as a traceback from here.
                    raise argparse.ArgumentError(
                        self,
                        'expected \'%s\' to be of format \'key=val\' or'
                        ' \'key\'' % value.strip())

            setattr(namespace, self.dest, results)

    class _ValidateFileAction(argparse.Action):
        """Validate a file can be read from before using it.
        """
        def __call__(self, parser, namespace, values, option_string=None):
            if not os.path.isfile(values):
                raise argparse.ArgumentError(
                    self, 'the path \'%s\' is not a valid path' % values)
            if not os.access(values, os.R_OK):
                raise argparse.ArgumentError(
                    self, 'the path \'%s\' is not accessible' % values)

            setattr(namespace, self.dest, values)

    class _ValidateDirAction(argparse.Action):
        """Validate a directory can be written to before using it.
        """
        def __call__(self, parser, namespace, values, option_string=None):
            if not os.path.isdir(values):
                raise argparse.ArgumentError(
                    self, 'the path \'%s\' is not a valid path' % values)
            if not os.access(values, os.W_OK):
                raise argparse.ArgumentError(
                    self, 'the path \'%s\' is not accessible' % values)

            setattr(namespace, self.dest, values)

    def list_logging_levels():
        """Give a summary of all available logging levels.

        :return: List of verbosity level names in decreasing order of
            verbosity
        """
        return sorted(VERBOSITY_LEVELS, key=VERBOSITY_LEVELS.get)

    parser = argparse.ArgumentParser(
        prog=__file__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--version', action='version', version='%(prog)s 0.2')
    parser.add_argument('--list', '--list-tests', action='store_true',
                        help='list all tests and exit')
    parser.add_argument('--list-trafficgens', action='store_true',
                        help='list all traffic generators and exit')
    parser.add_argument('--list-collectors', action='store_true',
                        help='list all system metrics loggers and exit')
    parser.add_argument('--list-vswitches', action='store_true',
                        help='list all system vswitches and exit')
    parser.add_argument('--list-settings', action='store_true',
                        help='list effective settings configuration and exit')
    parser.add_argument('test', nargs='*', help='test specification(s)')

    group = parser.add_argument_group('test selection options')
    group.add_argument('-f', '--test-spec', help='test specification file')
    group.add_argument('-d', '--test-dir', help='directory containing tests')
    # argparse collapses runs of whitespace when rendering help, so plain
    # adjacent literals render exactly like the former continuation lines
    group.add_argument('-t', '--tests',
                       help='Comma-separated list of terms indicating tests '
                            'to run. e.g. "RFC2544,!p2p" - run all tests '
                            'whose name contains RFC2544 less those '
                            'containing "p2p"')
    group.add_argument('--verbosity', choices=list_logging_levels(),
                       help='debug level')
    group.add_argument('--trafficgen', help='traffic generator to use')
    group.add_argument('--sysmetrics', help='system metrics logger to use')

    group = parser.add_argument_group('test behavior options')
    group.add_argument('--xunit', action='store_true',
                       help='enable xUnit-formatted output')
    group.add_argument('--xunit-dir', action=_ValidateDirAction,
                       help='output directory of xUnit-formatted output')
    group.add_argument('--load-env', action='store_true',
                       help='enable loading of settings from the environment')
    group.add_argument('--conf-file', action=_ValidateFileAction,
                       help='settings file')
    group.add_argument('--test-params', action=_SplitTestParamsAction,
                       help='semicolon-separated list of test parameters: '
                            'key=val;...')

    return vars(parser.parse_args())


def configure_logging(level):
    """Attach the console and file handlers to the root logger.

    The root logger itself is opened up to DEBUG; every handler then decides
    what it records:

    * a stdout handler limited to the requested verbosity,
    * a file handler (LOG_FILE_DEFAULT) that records everything,
    * three file handlers (LOG_FILE_HOST_CMDS, LOG_FILE_TRAFFIC_GEN,
      LOG_FILE_SYS_METRICS) that each keep only the records whose message
      starts with the CMD_PREFIX of 'tasks', 'trafficgen' and 'collector'
      respectively.

    :param level: Key of VERBOSITY_LEVELS selecting the stdout threshold.
    """
    def _log_path(file_setting):
        """Return the file named by setting 'file_setting' inside LOG_DIR."""
        return os.path.join(
            settings.getValue('LOG_DIR'), settings.getValue(file_setting))

    class _PrefixOnlyFilter(logging.Filter):
        """Let through only records whose message starts with the
        CMD_PREFIX of the given module; everything else is dropped."""
        def __init__(self, prefix_owner):
            super().__init__()
            self._prefix_owner = prefix_owner

        def filter(self, record):
            # the prefix is looked up for every record, not cached
            return record.getMessage().startswith(
                self._prefix_owner.CMD_PREFIX)

    default_path = _log_path('LOG_FILE_DEFAULT')
    filtered_logs = (
        (_log_path('LOG_FILE_HOST_CMDS'), tasks),
        (_log_path('LOG_FILE_TRAFFIC_GEN'), trafficgen),
        (_log_path('LOG_FILE_SYS_METRICS'), collector),
    )

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(VERBOSITY_LEVELS[level])
    console.setFormatter(logging.Formatter(
        '[%(levelname)s]  %(asctime)s : (%(name)s) - %(message)s'))
    root.addHandler(console)

    catch_all = logging.FileHandler(filename=default_path)
    catch_all.setLevel(logging.DEBUG)
    root.addHandler(catch_all)

    for path, prefix_owner in filtered_logs:
        handler = logging.FileHandler(filename=path)
        handler.setLevel(logging.DEBUG)
        handler.addFilter(_PrefixOnlyFilter(prefix_owner))
        root.addHandler(handler)


def apply_filter(tests, tc_filter):
    """Allow a subset of tests to be conveniently selected

    Terms are processed from left to right. A plain term adds every test
    whose name contains it; a term starting with '!' removes every test
    selected so far whose name contains the rest of the term. A test is
    never selected more than once, even if it matches several terms.

    :param tests: The list of Tests from which to select. Each needs a
        'name' attribute. The list itself is not modified.
    :param tc_filter: A case-insensitive string of comma-separated terms
        indicating the Tests to select. None is treated like ''.
        e.g. 'RFC' - select all tests whose name contains 'RFC'
        e.g. 'RFC,burst' - select all tests whose name contains 'RFC' or
            'burst'
        e.g. 'RFC,burst,!p2p' - select all tests whose name contains 'RFC'
            or 'burst' and from these remove any containing 'p2p'.
        e.g. '!p2p' - select all tests except those containing 'p2p'.
        e.g. '' - empty string selects all tests.
    :return: A list of the selected Tests.
    """
    terms = [term.strip() for term in (tc_filter or "").lower().split(",")]

    # If the filter opens with an exclusion there is nothing selected yet to
    # exclude from, so start from the complete list in that case.
    result = list(tests) if terms[0].startswith('!') else []
    # Track selections by identity: tests need not be hashable or comparable
    selected = {id(test) for test in result}

    for term in terms:
        if term.startswith('!'):
            # Term begins with '!' so we remove matching tests
            unwanted = term[1:]
            result = [test for test in result
                      if unwanted not in test.name.lower()]
            selected = {id(test) for test in result}
        else:
            # Add matching tests that are not selected already; an empty
            # term is contained in every name and so matches all tests
            for test in tests:
                if id(test) not in selected and term in test.name.lower():
                    selected.add(id(test))
                    result.append(test)

    return result


class MockTestCase(unittest.TestCase):
    """Allow use of xmlrunner to generate Jenkins compatible output without
    using xmlrunner to actually run tests.

    Each instance carries an outcome that was decided beforehand; running it
    merely replays that outcome under the given test name.

    Usage:
        suite = unittest.TestSuite()
        suite.addTest(MockTestCase('Test1 passed ', True, 'Test1'))
        suite.addTest(MockTestCase('Test2 failed because...', False, 'Test2'))
        xmlrunner.XMLTestRunner(...).run(suite)
    """

    def __init__(self, msg, is_pass, test_name):
        """Create a test called 'test_name' with a predetermined outcome.

        :param msg: Message reported when the test is a failure.
        :param is_pass: True if the test shall be reported as passed.
        :param test_name: Name under which the test is reported.
        """
        #remember the things
        self.msg = msg
        self.is_pass = is_pass

        #dynamically create a test method with the right name
        #but point the method at our generic test method.
        #The method is attached to this instance, not to the class: a bound
        #method stored on the class would be shared by every instance using
        #the same name, making all of them report the outcome of the last
        #one created.
        setattr(self, test_name, self.generic_test)

        super().__init__(test_name)

    def generic_test(self):
        """Provide a generic function that raises or not based
        on how self.is_pass was set in the constructor"""
        self.assertTrue(self.is_pass, self.msg)


def main():
    """Main function.

    Performs, in order:

    1. parse the command line and load the settings: the 'conf' directory,
       the optional '--conf-file', optionally the environment, and finally
       the command line again so that it takes priority;
    2. configure logging and check the requested traffic generator exists
       (exits with status 1 if it does not);
    3. create a TestCase for every entry of the PERFORMANCE_TESTS setting
       and narrow the list down with the '--tests' filter;
    4. serve a '--list*' request, if any, and exit;
    5. run the selected tests - a failing test is logged and the run
       continues with the next one - and, if the XUNIT setting is on,
       write an xUnit report of the outcomes;
    6. remove the results directory again if nothing was written to it.

    :raises Exception: Whatever TestCase() raised while a test was being
        created; the failure is logged before it is re-raised.
    """
    args = parse_arguments()

    # configure settings

    settings.load_from_dir('conf')

    # load command line parameters first in case there are settings files
    # to be used
    settings.load_from_dict(args)

    if args['conf_file']:
        settings.load_from_file(args['conf_file'])

    if args['load_env']:
        settings.load_from_env()

    # reload command line parameters since these should take higher priority
    # than both a settings file and environment variables
    settings.load_from_dict(args)

    configure_logging(settings.getValue('VERBOSITY'))
    logger = logging.getLogger()

    # configure trafficgens

    if args['trafficgen']:
        trafficgens = Loader().get_trafficgens()
        if args['trafficgen'] not in trafficgens:
            logger.error('There are no trafficgens matching \'%s\' found in'
                         ' \'%s\'. Exiting...', args['trafficgen'],
                         settings.getValue('TRAFFICGEN_DIR'))
            sys.exit(1)

    # generate results directory name
    date = datetime.datetime.fromtimestamp(time.time())
    results_dir = "results_" + date.strftime('%Y-%m-%d_%H-%M-%S')
    results_path = os.path.join(settings.getValue('LOG_DIR'), results_dir)

    # configure tests
    testcases = settings.getValue('PERFORMANCE_TESTS')
    all_tests = []
    for cfg in testcases:
        try:
            all_tests.append(TestCase(cfg, results_path))
        except Exception:
            logger.exception("Failed to create test: %s",
                             cfg.get('Name', '<Name not set>'))
            raise

    # select the requested tests
    all_tests = apply_filter(all_tests, args['tests'])

    # if required, handle list-* operations
    # (sys.exit is used rather than the exit() helper, which is only
    # installed by the 'site' module and is meant for interactive use)

    if args['list']:
        print("Available Tests:")
        print("======")
        for test in all_tests:
            print('* %-18s%s' % ('%s:' % test.name, test.desc))
        sys.exit(0)

    if args['list_trafficgens']:
        print(Loader().get_trafficgens_printable())
        sys.exit(0)

    if args['list_collectors']:
        print(Loader().get_collectors_printable())
        sys.exit(0)

    if args['list_vswitches']:
        print(Loader().get_vswitches_printable())
        sys.exit(0)

    if args['list_settings']:
        print(str(settings))
        sys.exit(0)

    # create results directory
    # (check the full path inside LOG_DIR - the bare directory name would be
    # looked up relative to the current working directory)
    if not os.path.exists(results_path):
        logger.info("Creating result directory: %s", results_path)
        os.makedirs(results_path)

    # outcomes are collected as mock test cases so that xmlrunner can turn
    # them into a report without running anything itself
    suite = unittest.TestSuite()

    # run tests
    for test in all_tests:
        try:
            test.run()
            suite.addTest(MockTestCase('', True, test.name))
        #pylint: disable=broad-except
        except Exception as ex:
            logger.exception("Failed to run test: %s", test.name)
            suite.addTest(MockTestCase(str(ex), False, test.name))
            logger.info("Continuing with next test...")

    if settings.getValue('XUNIT'):
        xmlrunner.XMLTestRunner(
            output=settings.getValue('XUNIT_DIR'), outsuffix="",
            verbosity=0).run(suite)

    #remove directory if no result files were created.
    if os.path.exists(results_path) and not os.listdir(results_path):
        shutil.rmtree(results_path)

# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()