#!/usr/bin/env python3

import argparse
import os
import shutil
import sys
import random
import copy
from pathlib import Path
from colorama import Fore, Style
from atom_core.config_io import atomError

# Atom imports
from atom_core.dataset_io import (loadJSONFile,
                                  createJSONFile, saveAtomDataset)
from atom_core.system import execute

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def filterDatasetByCollections(dataset, collection_list):
    """Remove, in place, every collection whose key is not in collection_list.

    Args:
        dataset: Dataset dictionary with a 'collections' dictionary. It is
            modified in place.
        collection_list: Container of collection keys to keep.

    Returns:
        None. The result is the mutated ``dataset`` argument.
    """
    collections = dataset['collections']
    # Iterate over a snapshot of the keys so entries can be deleted while looping,
    # without having to deep copy the complete dataset.
    for collection_key in list(collections):
        if collection_key not in collection_list:
            del collections[collection_key]

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


def main():
    """Merge the collections of two ATOM datasets into a new dataset folder.

    Command line arguments:
        -ds1 / --dataset_1_json: json file of the first dataset. Its folder is
            copied in full to the output folder and its collections keep their keys.
        -ds2 / --dataset_2_json: json file of the second dataset. Its collections
            are appended after the highest collection key of dataset 1, and the
            corresponding data files are copied and renamed to the new keys.
        -dof / --dataset_out_folder: folder where the merged dataset is written.
        -ow / --overwrite: delete dataset_out_folder first if it already exists.

    The program exits with an error message if either json file does not exist.
    """
    ap = argparse.ArgumentParser(description='Merges the collections of two ATOM datasets into a single dataset.')
    ap.add_argument("-ds1", "--dataset_1_json",
                    help="Dataset 1 json.", type=str, required=True)
    ap.add_argument("-ds2", "--dataset_2_json",
                    help="Dataset 2 json.", type=str, required=True)
    ap.add_argument("-dof", "--dataset_out_folder",
                    help="Folder on which to put the dataset out.", type=str, required=True)
    ap.add_argument("-ow", "--overwrite", action="store_true", default=False,
                    help="Overwrites dataset_out_folder.")

    args = vars(ap.parse_known_args()[0])

    args['dataset_1_json'] = str(Path(args['dataset_1_json']))
    args['dataset_2_json'] = str(Path(args['dataset_2_json']))
    # os.path.dirname returns '' for a bare filename. Fall back to the current
    # directory so that copytree and the data file paths below remain valid.
    dataset_1_folder = os.path.dirname(args['dataset_1_json']) or os.curdir
    dataset_2_folder = os.path.dirname(args['dataset_2_json']) or os.curdir

    # ---------------------------------------
    # --- Verify that the datasets are ok
    # ---------------------------------------
    if not os.path.isfile(args['dataset_1_json']):
        sys.exit(Fore.RED + 'Dataset ' + args['dataset_1_json'] + ' does not exist.' + Style.RESET_ALL)

    if not os.path.isfile(args['dataset_2_json']):
        sys.exit(Fore.RED + 'Dataset ' + args['dataset_2_json'] + ' does not exist.' + Style.RESET_ALL)

    # Load the json files
    print('Loading json file ' + args['dataset_1_json'])
    dataset_1 = loadJSONFile(args['dataset_1_json'])

    print('Loading json file ' + args['dataset_2_json'])
    dataset_2 = loadJSONFile(args['dataset_2_json'])

    # Create the output folder
    if os.path.exists(args['dataset_out_folder']):
        if args['overwrite']:
            shutil.rmtree(args['dataset_out_folder'])
        else:
            # NOTE(review): presumably atomError aborts the program; otherwise the
            # copytree below raises because the folder exists. Confirm.
            atomError("Dataset out folder already exists. If you want to overwrite use flag -ow")

    # The output folder starts as a full copy of dataset 1 (json and data files).
    shutil.copytree(dataset_1_folder, args['dataset_out_folder'])

    # Run the actual merge
    # NOTE this assumes that all data apart from collections is similar in dataset_1 and dataset_2.
    # we are not testing this for now ...
    dataset_merged = copy.deepcopy(dataset_1)

    collection_keys = [int(key) for key in dataset_merged['collections']]
    print(collection_keys)
    # default=-1 makes the first appended collection 000 when dataset 1 has no collections.
    next_collection = max(collection_keys, default=-1)

    for collection_key, collection in dataset_2["collections"].items():
        next_collection += 1
        next_collection_key = str(next_collection).zfill(3)

        print('Adding collection ' + collection_key + ' from dataset ' + args['dataset_2_json'] + ' to output dataset ' + Fore.BLUE +
              args['dataset_out_folder'] + Style.RESET_ALL + ' as collection ' + Fore.BLUE + next_collection_key + Style.RESET_ALL)

        # The collection dictionary is shared with dataset_2 (not copied), so the
        # data_file updates below are visible through dataset_merged as well.
        dataset_merged['collections'][next_collection_key] = collection

        # change the datafile fields
        for sensor_key, sensor in collection['data'].items():
            data_file = sensor['data_file']
            print(data_file)
            # splitext only splits at the last dot, so names or paths that contain
            # additional dots keep their full stem and their real extension.
            name, ext = os.path.splitext(data_file)
            # Assumes the file stem ends with a 3 character collection key — TODO confirm.
            new_data_file = name[:-3] + next_collection_key + ext
            print('Changing data_file from ' + sensor['data_file'] + ' to ' + new_data_file)
            sensor['data_file'] = new_data_file

            # Copy file
            full_old_data_file = os.path.join(dataset_2_folder, data_file)
            full_new_data_file = os.path.join(args['dataset_out_folder'], new_data_file)

            print('Copying from data_file from ' + full_old_data_file + ' to ' + full_new_data_file)
            shutil.copy(full_old_data_file, full_new_data_file)

    # Save datasets
    dataset_merged['_metadata']['dataset_name'] += 'merged'

    print('Writing updated dataset.json')
    # Data files were already copied above, hence save_data_files=False.
    saveAtomDataset(os.path.join(args['dataset_out_folder'], 'dataset.json'), dataset_merged,
                    save_data_files=False, freeze_dataset=False)


# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
