#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

import json
import os
import sys

import argparse
import cv2
from matplotlib import cm
import numpy as np
from prettytable import PrettyTable
from collections import OrderedDict
from scipy.spatial import distance
from colorama import Style, Fore
from copy import deepcopy

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import copyTFToDataset, filterPatternsFromDataset, getMixedDataset, loadResultsJSON, filterCollectionsFromDataset, filterSensorsFromDataset
from atom_core.drawing import drawCross2D, drawSquare2D
from atom_core.utilities import atomStartupPrint, createLambdaExpressionsForArgs, saveFileResults, atomError
from atom_core.vision import projectToCamera


# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------
def getPointsInPatternAsNPArray(_collection_key, _pattern_key, _sensor_key, _dataset):
    """Return the pattern corners matching a sensor's detections as a 4xN array.

    For every detection listed in
    ``_dataset['collections'][_collection_key]['labels'][_pattern_key][_sensor_key]['idxs']``
    the corner with the same ``id`` is fetched from
    ``_dataset['patterns'][_pattern_key]['corners']``. Columns follow the order of
    the detections.

    Args:
        _collection_key: key of the collection in ``_dataset['collections']``.
        _pattern_key: key of the pattern in ``_dataset['patterns']`` and in the labels.
        _sensor_key: key of the sensor in the collection labels.
        _dataset: dataset dictionary holding the ``collections`` and ``patterns`` entries.

    Returns:
        A float numpy array with four rows (x, y, 0, 1) and one column per detection.

    Raises:
        IndexError: if a detected id has no corresponding corner in the pattern.
    """
    detections = _dataset['collections'][_collection_key]['labels'][_pattern_key][
        _sensor_key]['idxs']

    # Index the corners by id once, instead of scanning the full corner list for each
    # detection. setdefault keeps the first corner when an id is repeated, which is the
    # corner a first-match linear scan would return.
    corners_by_id = {}
    for corner in _dataset['patterns'][_pattern_key]['corners']:
        corners_by_id.setdefault(corner['id'], corner)

    pts_in_pattern_list = []  # corners, ordered like the detections
    for pt_detected in detections:
        id_detected = pt_detected['id']
        if id_detected not in corners_by_id:
            raise IndexError('Detected corner id ' + str(id_detected) +
                             ' not found in the corners of pattern ' + str(_pattern_key))
        pts_in_pattern_list.append(corners_by_id[id_detected])

    return np.array([[item['x'] for item in pts_in_pattern_list],  # convert list to np array
                     [item['y'] for item in pts_in_pattern_list],
                     [0 for _ in pts_in_pattern_list],
                     [1 for _ in pts_in_pattern_list]], float)


def getPointsDetectedInImageAsNPArray(_collection_key, _pattern_key, _sensor_key, _dataset):
    """Return the detected image points of a pattern in a sensor as a 2xN array.

    The detections are read from
    ``_dataset['collections'][_collection_key]['labels'][_pattern_key][_sensor_key]['idxs']``.

    Returns:
        A float numpy array whose first row holds the ``x`` values and whose second
        row holds the ``y`` values, one column per detection.
    """
    detections = _dataset['collections'][_collection_key]['labels'][_pattern_key][
        _sensor_key]['idxs']
    xs = [detection['x'] for detection in detections]
    ys = [detection['y'] for detection in detections]
    return np.array([xs, ys], dtype=float)

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


if __name__ == "__main__":
    # Script entry point. Loads a train (calibrated) and a test dataset, mixes them, and for
    # every collection / pattern / rgb sensor computes the mean pixel distance between the
    # pattern corners projected into the image and the detected corners. The results are
    # printed as a table and optionally saved to a csv file.

    atomStartupPrint('Single rgb evaluation')

    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file",
                    help="Json file containing input training dataset.", type=str, required=True)
    ap.add_argument("-test_json", "--test_json_file",
                    help="Json file containing input testing dataset.", type=str, required=True)
    ap.add_argument(
        "-csf", "--collection_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a collection name as input and "
        "returns True or False to indicate if the collection should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument(
        "-ssf", "--sensor_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a sensor name as input and "
        "returns True or False to indicate if the sensor should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["left_laser", "frontal_camera"] , to load only '
        "sensors left_laser and frontal_camera")
    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results",
                    action='store_true', default=False)
    ap.add_argument(
        "-sfrn", "--save_file_results_name",
        help="Name of csv file to save the results. "
        "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv",
        type=str, required=False)
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument(
        "-rpd", "--remove_partial_detections",
        help="Remove detected labels which are only partial."
        "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument(
        "-psf", "--pattern_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a pattern name as input (as defined in the config.yml) and returns True or False to indicate if the pattern should be used in the optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        "language. Example: lambda name: name in ['pattern_1', 'charuco_200x300'] , to load only these two patterns")
    ap.add_argument("-si", "--show_images", help="If true the script shows images.",
                    action='store_true', default=False)
    ap.add_argument(
        "-cpt", "--copy_pattern_transforms",
        help="If true the transformations of the pattern will be copied from train to mixed dataset. See https://github.com/lardemua/atom/issues/870",
        action='store_true', default=False)

    # - Save args
    # Arguments starting with "__" are dropped before parsing (presumably arguments injected
    # by the ROS launcher - TODO confirm).
    arglist = [x for x in sys.argv[1:] if not x.startswith("__")]
    # these args have the selection functions as strings
    args_original = vars(ap.parse_args(args=arglist))
    args = createLambdaExpressionsForArgs(args_original)  # selection functions are now lambdas

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads the train json file containing the calibration results
    train_dataset, train_json_file = loadResultsJSON(
        args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(
        args["test_json_file"],
        args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections
    test_dataset = filterSensorsFromDataset(test_dataset, args)  # filter sensors

    # ---------------------------------------
    # --- Get mixed json (calibrated transforms from train and the rest from test)
    # ---------------------------------------
    mixed_dataset = getMixedDataset(train_dataset, test_dataset)
    mixed_dataset = filterPatternsFromDataset(mixed_dataset, args)

    # ---------------------------------------
    # --- Copy transforms of the patterns (if they are fixed).
    # ---------------------------------------
    # NOTE Not sure if we should do this all the time. see discussion in
    # https://github.com/lardemua/atom/issues/870
    if args['copy_pattern_transforms']:
        for pattern in train_dataset['calibration_config']['calibration_patterns'].values():
            if pattern['fixed'] is True:
                copyTFToDataset(
                    pattern['parent_link'],
                    pattern['link'],
                    train_dataset, mixed_dataset)

    # ---------------------------------------
    # --- Compute the mean reprojection error per collection, pattern and rgb sensor
    # ---------------------------------------
    # e[collection_key][pattern_key][sensor_key] -> mean corner error in pixels. A sensor key
    # is only present when the pattern was detected by that sensor in that collection.
    e = {}
    for collection_key, collection in mixed_dataset['collections'].items():
        if 'labels' not in collection:
            print(f'Collection {collection_key} does not have labels information. Skipping...')
            continue
        e[collection_key] = {}

        for pattern_key, pattern in mixed_dataset['calibration_config']['calibration_patterns'].items():
            e[collection_key][pattern_key] = {}

            for sensor_key, sensor in mixed_dataset['sensors'].items():  # iterate all sensors
                if sensor['modality'] != 'rgb':
                    continue

                # pattern not detected by sensor in collection
                if not collection['labels'][pattern_key][sensor_key]['detected']:
                    continue

                # Get the pattern corners in the local pattern frame. Only the corners which
                # correspond to the detected points stored in the labels are used.
                pts_in_pattern = getPointsInPatternAsNPArray(
                    collection_key, pattern_key, sensor_key, mixed_dataset)

                # Transform the pts from the pattern's reference frame to the sensor's reference frame
                from_frame = sensor['parent']
                to_frame = pattern['link']
                sensor_to_pattern = getTransform(from_frame, to_frame, collection['transforms'])

                pts_in_sensor = np.dot(sensor_to_pattern, pts_in_pattern)

                # Project points to the image of the sensor
                w = collection['data'][sensor_key]['width']
                h = collection['data'][sensor_key]['height']
                # Convert the values to float explicitly. Handing np.ndarray an integer array
                # as buffer= would reinterpret its raw bytes as floats instead of converting.
                K = np.array(sensor['camera_info']['K'], dtype=float).reshape((3, 3))
                # Only the first five distortion coefficients are used.
                D = np.array(sensor['camera_info']['D'], dtype=float).ravel()[:5].reshape((5, 1))

                pts_in_image, _, _ = projectToCamera(K, D, w, h, pts_in_sensor[0:3, :])

                # Get the detected points to use as ground truth
                pts_detected_in_image = getPointsDetectedInImageAsNPArray(
                    collection_key, pattern_key, sensor_key, mixed_dataset)

                if args['show_images']:
                    print('showing image for collection ' + collection_key +
                          ' pattern ' + pattern_key + ' sensor ' + sensor_key)

                    image_path = collection['data'][sensor_key]['data_file']
                    dataset_path = os.path.dirname(test_json_file)
                    full_image_path = os.path.join(dataset_path, image_path)
                    print(full_image_path)
                    image = cv2.imread(full_image_path)

                    print(pts_in_image)
                    print(pts_in_image.shape)

                    lst_pts = pts_in_image.tolist()
                    print(lst_pts)
                    xs = lst_pts[0]
                    ys = lst_pts[1]
                    print(xs)

                    if image is None:
                        # cv2.imread returns None when the file cannot be read. Skip the
                        # display but still compute the errors below.
                        print(Fore.YELLOW + 'Could not read image ' + full_image_path +
                              '. Skipping display.' + Style.RESET_ALL)
                    else:
                        # Draw the projected corners on the image (crosses)
                        for x, y in zip(xs, ys):
                            drawCross2D(image, x, y, 15, color=(255, 0, 0), thickness=1)

                        cv2.namedWindow('image', cv2.WINDOW_NORMAL)
                        cv2.imshow('image', image)
                        cv2.waitKey(0)

                # Euclidean distance in pixels between each projected corner and its detection
                num_detected = pts_detected_in_image.shape[1]
                dx = pts_in_image[0, :num_detected] - pts_detected_in_image[0, :]
                dy = pts_in_image[1, :num_detected] - pts_detected_in_image[1, :]
                corners_errors = np.sqrt(dx ** 2 + dy ** 2)
                e[collection_key][pattern_key][sensor_key] = np.mean(corners_errors)

    if not e:
        atomError('No rgb labelling found. Exiting...')

    # -------------------------------------------------------------
    # Print output table
    # -------------------------------------------------------------
    # One column per (pattern, rgb sensor) pair, in the same order for header and rows.
    rgb_columns = [(pattern_key, sensor_key)
                   for pattern_key in mixed_dataset['calibration_config']['calibration_patterns']
                   for sensor_key, sensor in mixed_dataset['sensors'].items()
                   if sensor['modality'] == 'rgb']

    table_header = ['Collection']
    for pattern_key, sensor_key in rgb_columns:
        table_header.append(sensor_key + '-' + pattern_key + ' [px]')
    table_header.append('all' + ' [px]')

    table = PrettyTable(table_header)
    # Table to save. It holds the same values as the printed table but without colors,
    # because the color codes would show up as random characters in the output csv.
    table_to_save = PrettyTable(table_header)

    od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
    for collection_key, collection in od.items():
        if 'labels' not in collection:
            continue
        row = [collection_key]
        row_save = [collection_key]

        errors = []
        for pattern_key, sensor_key in rgb_columns:
            if sensor_key in e[collection_key][pattern_key]:
                error = float(e[collection_key][pattern_key][sensor_key])
                value = '%.4f' % error
                row.append(value)
                row_save.append(value)
                errors.append(error)
            else:
                row.append(Fore.LIGHTBLACK_EX + '---' + Style.RESET_ALL)
                row_save.append('---')

        # Add all as an average of the errors for this row
        average = '%.4f' % (sum(errors) / len(errors)) if errors else '---'
        row.append(average)
        row_save.append(average)

        table.add_row(row)
        table_to_save.add_row(row_save)  # uncolored version of the row

    # Compute averages and add a bottom row
    bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL]
    bottom_row_save = ['Averages']
    for col_idx in range(1, len(table_header)):
        column_values = []
        for row in table_to_save.rows:
            try:
                column_values.append(float(row[col_idx]))
            except (ValueError, TypeError):
                continue  # '---' placeholder, no value for this cell

        if column_values:
            value = '%.4f' % (sum(column_values) / len(column_values))
        else:
            value = '---'
        bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
        bottom_row_save.append(value)

    table.add_row(bottom_row)
    table_to_save.add_row(bottom_row_save)

    # Put larger errors in red per column (per sensor)
    for col_idx in range(1, len(table_header)):
        max_value = float('-inf')
        max_row_idx = None
        for row_idx, row in enumerate(table.rows[: -1]):  # ignore bottom row
            try:
                value = float(row[col_idx])
            except (ValueError, TypeError):
                continue  # colored '---' placeholder

            if value > max_value:
                max_value = value
                max_row_idx = row_idx

        # set the max column value to red (only if the column has at least one value)
        if max_row_idx is not None:
            table.rows[max_row_idx][col_idx] = Fore.RED + \
                table.rows[max_row_idx][col_idx] + Style.RESET_ALL

    table.align = 'c'
    table_to_save.align = 'c'
    print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
    print(table)

    # save results in csv file
    if args['save_file_results']:
        if args['save_file_results_name'] is None:
            results_name = 'rgb_results.csv'
            saveFileResults(
                args['train_json_file'],
                args['test_json_file'],
                results_name, table_to_save)
        else:
            with open(args['save_file_results_name'], 'w', newline='') as f_output:
                f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()
