#!/usr/bin/env python3
"""
Runs several calibration executions in batch. A Jinja2 YAML template, rendered with the variables of a YAML data file, configures the batch executions.
"""
import argparse
import os
import shutil
import subprocess
import yaml
import jinja2
from jinja2 import Environment, FileSystemLoader
from sklearn.model_selection import StratifiedKFold, KFold, LeaveOneOut, StratifiedShuffleSplit

from colorama import Fore, Back, Style

from atom_core.config_io import uriReader
from atom_core.dataset_io import loadJSONFile
from atom_core.system import resolvePath, execute
from atom_core.utilities import atomWarn


def bprint(text):
    """Batch print: write text to stdout in blue over a yellow background.

    Args:
        text (str): Message to print. The color codes are concatenated to it,
            so it must be a string.
    """
    color_prefix = Fore.BLUE + Back.YELLOW
    colored_text = color_prefix + text + Style.RESET_ALL
    print(colored_text)


def generateClasses(dataset):
    """
    Generate one class label per collection of an ATOM dataset.

    Each label describes which sensors detected which pattern in the collection and
    follows the format:
        pattern1--sensorA-sensorB---pattern2--sensorC---[...]
    i.e. '--' separates a pattern from its detecting sensors, '-' separates the
    sensors, and '---' separates consecutive patterns. A pattern that no sensor
    detected contributes 'pattern--' (empty sensor list).

    Args:
        dataset (dict): ATOM dataset. Must contain the keys 'collections', 'patterns'
            and 'sensors', and each collection must provide
            ['labels'][pattern_key][sensor_key]['detected'].

    Returns:
        tuple: (classes, collection_keys), two lists of equal length where classes[i]
            is the label of the collection collection_keys[i].

    """
    pattern_keys = list(dataset['patterns'].keys())
    sensor_keys = list(dataset['sensors'].keys())
    collection_keys = list(dataset['collections'].keys())

    classes = []
    for collection_key in collection_keys:
        collection = dataset['collections'][collection_key]

        pattern_parts = []
        for pattern_key in pattern_keys:
            # The detected sensors are gathered per pattern, so that detections of
            # one pattern do not leak into the class segment of the next pattern.
            detected_sensors = [sensor_key for sensor_key in sensor_keys
                                if collection['labels'][pattern_key][sensor_key]['detected']]
            pattern_parts.append(pattern_key + '--' + '-'.join(detected_sensors))

        # Joining avoids a trailing separator altogether; str.rstrip('---') would
        # strip every trailing '-' character, not just the '---' suffix.
        classes.append('---'.join(pattern_parts))

    return classes, collection_keys


def _toPlainList(values):
    """Return values as a list of built-in Python objects (numpy arrays are converted with tolist())."""
    # numpy scalars inside a list are rendered as e.g. 'np.int64(0)' by recent numpy
    # versions, which would end up verbatim in the rendered template.
    return values.tolist() if hasattr(values, 'tolist') else list(values)


def _computeFolds(cross_validation, collection_keys, classes):
    """Build the list of [train, test] folds described by the cross validation config.

    Args:
        cross_validation (dict): May contain 'type', 'n_splits' and 'train_size'.
        collection_keys (list): Keys of the dataset collections.
        classes (list): Class label of each collection, used by the stratified splitters.

    Returns:
        list: One [train, test] pair of lists per fold.
    """
    cv_type = cross_validation.get('type')
    n_splits = cross_validation.get('n_splits')
    train_size = cross_validation.get('train_size')

    if cv_type == 'stratified-k-fold':
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=True)
        folds = splitter.split(collection_keys, classes)
    elif cv_type == 'k-fold':
        splitter = KFold(n_splits=n_splits, shuffle=True)
        folds = splitter.split(collection_keys)
    elif cv_type == 'leave-one-out':
        folds = LeaveOneOut().split(collection_keys)
    elif cv_type == 'stratified-shuffle-split' and train_size:
        splitter = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size)
        folds = splitter.split(collection_keys, classes)
    else:
        if cv_type is not None:
            # A type was requested but cannot be honored: tell the user instead of
            # silently falling back.
            atomWarn('Cross validation type ' + str(cv_type) +
                     ' is not supported or is missing parameters.')
        # NOTE(review): the scikit-learn splitters yield integer positions into
        # collection_keys, whereas this fallback holds the keys themselves - confirm
        # that the templates cope with both.
        folds = [(collection_keys, collection_keys)]
        bprint('Running without any cross validation.')

    return [[_toPlainList(train), _toPlainList(test)] for (train, test) in folds]


def _prepareOutputFolder(output_folder, overwrite):
    """Create output_folder if missing; if it exists, recreate it empty only when overwrite is set."""
    if os.path.exists(output_folder):
        if not overwrite:
            return  # keep the existing content
        shutil.rmtree(output_folder)  # remove the previous results
    os.makedirs(output_folder)  # also creates missing parent folders


def _collectFiles(files_to_collect, experiment_folder):
    """Copy every file listed in files_to_collect into experiment_folder.

    Raises:
        ValueError: If an entry is None or its resolved path does not exist.
    """
    for file_uri in files_to_collect:
        if file_uri is None:
            raise ValueError('File in files to collect is None. Aborting.')

        resolved_file, _, _ = uriReader(file_uri)

        if not os.path.exists(resolved_file):
            raise ValueError('File ' + file_uri + ', resolved to ' + resolved_file +
                             ' should be collected but does not exist.')

        filename_out = experiment_folder + '/' + os.path.basename(resolved_file)
        bprint('Copying file ' + resolved_file + ' to ' + filename_out)
        # Copy in-process rather than through a 'cp' shell string, which breaks on
        # paths containing spaces or shell metacharacters.
        shutil.copy(resolved_file, filename_out)


def main():
    """Render the batch template and run the preprocessing command and every experiment.

    Steps:
        1. Load the data yaml and the dataset it points to (data['dataset_path']).
        2. Compute the cross validation folds and expose them to the template as 'folds'.
        3. Render the Jinja2 template (also saved to ./auto_rendered.yaml for debugging).
        4. Prepare the output folder and execute the preprocessing command.
        5. Execute each experiment command and collect its files into
           <output_folder>/<experiment_key>.

    In dry run mode the commands are only printed: nothing is executed and the output
    folder is neither created nor deleted.
    """
    ap = argparse.ArgumentParser()  # Parse command line arguments
    ap.add_argument("-v", "--verbose", help="Prints the stdout_data of each command to the terminal.",
                    action='store_true', default=False)
    ap.add_argument("-tf", "--template_filename", help="Jinja2 yaml containing the batches.",
                    required=True, type=str)
    ap.add_argument("-df", "--data_filename", help="Yaml containing variables used in the template file.",
                    required=True, type=str)
    ap.add_argument("-of", "--output_folder", help="Folder where to store the results",
                    required=True, type=str)
    ap.add_argument("-ow", "--overwrite", help="Overwrite output folder if needed.",
                    required=False, action='store_true')
    ap.add_argument("-dr", "--dry_run", help="Run without actually executing the processes.",
                    required=False, action='store_true')

    args = vars(ap.parse_args())

    # Load data.yml
    with open(args['data_filename']) as data_file:
        data = yaml.safe_load(data_file)

    # Load dataset
    dataset = loadJSONFile(data['dataset_path'] + '/dataset.json')

    # Generate classes and collection keys
    classes, collection_keys = generateClasses(dataset)

    # Dataset is no longer needed
    del dataset

    # Safeguard to guarantee backwards compatibility with data files that have no
    # (or an empty) cross_validation section.
    if not data.get('cross_validation'):
        data['cross_validation'] = {'type': None, 'n_splits': None}

    # Add folds to data, so that the template can use them
    data['folds'] = _computeFolds(data['cross_validation'], collection_keys, classes)

    # Template engine setup
    file_loader = FileSystemLoader(os.path.dirname(args['template_filename']))
    env = Environment(loader=file_loader, undefined=jinja2.StrictUndefined)
    env.add_extension('jinja2.ext.do')
    template = env.get_template(os.path.basename(args['template_filename']))

    # Save the rendered jinja file just for debug
    rendered = template.render(data)
    with open('auto_rendered.yaml', 'w') as rendered_file:
        rendered_file.write(rendered)

    config = yaml.safe_load(rendered)

    dry_run = args['dry_run']

    # Create output folder (a dry run must not touch previous results)
    args['output_folder'] = resolvePath(args['output_folder'])
    if not dry_run:
        _prepareOutputFolder(args['output_folder'], args['overwrite'])

    # Run preprocessing
    print('\n')
    bprint('Executing preprocessing command:\n' + config['preprocessing']['cmd'])

    if dry_run:
        bprint('Running in dry run mode...')
    else:
        execute(config['preprocessing']['cmd'], verbose=args['verbose'],
                save_path=args['output_folder'], save_filename_additions='preprocessing_')

    # Run experiments
    num_experiments = len(config['experiments'])
    for idx, (experiment_key, experiment) in enumerate(config['experiments'].items()):
        bprint(Style.BRIGHT + 'Experiment ' + str(idx) + ' of ' + str(num_experiments-1) + ': ' + experiment_key)
        bprint('Executing command:\n' + experiment['cmd'])

        if dry_run:
            bprint('Running in dry run mode...')
            continue

        experiment_folder = args['output_folder'] + '/' + experiment_key
        if os.path.exists(experiment_folder) and not args['overwrite']:
            atomWarn('Folder ' + experiment_folder + ' exists. Skipping batch experiment.')
            continue

        # exist_ok: with overwrite set, the folder may already have been recreated
        # (e.g. by the preprocessing command).
        os.makedirs(experiment_folder, exist_ok=True)

        # Start executing command.
        execute(experiment['cmd'], verbose=args['verbose'], save_path=experiment_folder)

        # Collect stdout_data files (an experiment may have nothing to collect)
        _collectFiles(experiment.get('files_to_collect') or [], experiment_folder)



# Run the batch execution only when this file is executed as a script.
if __name__ == "__main__":
    main()
