#!/usr/bin/env python3

import argparse
import os
import sys
import random
import copy
from pathlib import Path
from colorama import Fore, Style

# Atom imports
from atom_core.dataset_io import (loadJSONFile,
                                  createJSONFile)
from atom_core.utilities import atomStartupPrint

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def filterDatasetByCollections(dataset, collection_list):
    """Remove, in place, every collection whose key is not in collection_list.

    Args:
        dataset: Dictionary with a 'collections' entry mapping collection
            keys to collection data. It is modified in place.
        collection_list: Container with the collection keys to keep. Keys
            listed here that do not exist in the dataset are ignored.

    Returns:
        None. The result is the mutated ``dataset`` argument.
    """
    collections = dataset['collections']

    # Snapshot only the keys to drop: a dictionary cannot change size while it
    # is being iterated, and copying the whole dataset just for that is
    # needlessly expensive.
    keys_to_remove = [collection_key for collection_key in collections
                      if collection_key not in collection_list]

    for collection_key in keys_to_remove:
        del collections[collection_key]

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


def main():
    """Split an ATOM dataset json into a train and a test dataset.

    Parses the command line, randomly assigns ``train_percentage`` percent of
    the collections (rounded down) to the train dataset and the remaining
    collections to the test dataset, and writes both next to the input file
    as ``<name>_train<ext>`` and ``<name>_test<ext>``.

    Exits with an error message if the train percentage is outside [0, 100]
    or if the dataset file does not exist.
    """

    atomStartupPrint('Splitting dataset')

    ap = argparse.ArgumentParser(description='Splits an ATOM dataset json into train and test datasets.')
    ap.add_argument("-json", "--dataset_json",
                    help="Dataset json file to be split.", type=str, required=True)
    ap.add_argument("-tp", "--train_percentage", help="Percentage of dataset to be used to the train dataset.",
                    required=True, type=int)
    ap.add_argument("-ss", "--sample_seed", help="Sampling seed", type=int, default=None)

    args = vars(ap.parse_known_args()[0])

    dataset_path = Path(args['dataset_json'])
    args['dataset_json'] = str(dataset_path)

    # ---------------------------------------
    # --- Verify that the arguments are ok
    # ---------------------------------------
    # Outside this range random.sample would raise a ValueError (negative
    # sample size or sample larger than the population).
    if not 0 <= args['train_percentage'] <= 100:
        sys.exit(Fore.RED + 'Train percentage must be between 0 and 100, got ' +
                 str(args['train_percentage']) + '.' + Style.RESET_ALL)

    # Define random seed if given in args
    if args["sample_seed"] is not None:
        random.seed(args["sample_seed"])

    # ---------------------------------------
    # --- Verify that the dataset is ok
    # ---------------------------------------
    if not os.path.isfile(args['dataset_json']):
        sys.exit(Fore.RED + 'Dataset ' + args['dataset_json'] + ' does not exist.' + Style.RESET_ALL)

    # Load a json file
    print('Loading json file ' + args['dataset_json'])
    dataset = loadJSONFile(args['dataset_json'])

    # Get list of collections
    collection_list = list(dataset['collections'].keys())

    # Create train and test collection list
    number_of_train_collections = int(len(collection_list) * args['train_percentage'] / 100)
    train_collection_list = random.sample(collection_list, number_of_train_collections)

    # Every collection that was not sampled for training goes to the test
    # dataset (kept in the dataset's original order).
    train_collection_set = set(train_collection_list)
    test_collection_list = [collection_key for collection_key in collection_list
                            if collection_key not in train_collection_set]

    # Create train and test datasets
    train_dataset = copy.deepcopy(dataset)
    test_dataset = copy.deepcopy(dataset)

    # Filter datasets
    filterDatasetByCollections(train_dataset, train_collection_list)
    filterDatasetByCollections(test_dataset, test_collection_list)

    # Build the output paths from the file name only. A plain string replace of
    # '.json' would also rewrite directory names containing '.json', and would
    # return the input path unchanged (overwriting the original dataset) when
    # the file does not have a '.json' extension.
    train_json = str(dataset_path.with_name(dataset_path.stem + '_train' + dataset_path.suffix))
    test_json = str(dataset_path.with_name(dataset_path.stem + '_test' + dataset_path.suffix))

    # Save datasets
    train_dataset['_metadata']['dataset_name'] += '_train'
    createJSONFile(train_json, train_dataset)

    test_dataset['_metadata']['dataset_name'] += '_test'
    createJSONFile(test_json, test_dataset)


# Run the splitter only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
