#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

# Standard imports
import json
import math
import os
import argparse
import sys
from collections import OrderedDict

import numpy as np
import cv2
from prettytable import PrettyTable
from colorama import Style, Fore
from copy import deepcopy

# Atom imports
from atom_core.atom import getTransform
from atom_core.vision import depthToImage
from atom_core.dataset_io import getMixedDataset, readAnnotationFile, loadResultsJSON, filterCollectionsFromDataset
from atom_core.utilities import rootMeanSquare, saveFileResults

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


# Pure helpers used by the script below (no I/O, no project types).
def findClosestAnnotation(x_proj, y_proj, collection_annotations):
    """Find the annotated pixel closest to a projected point.

    Args:
        x_proj: x coordinate, in pixels, of the projected point.
        y_proj: y coordinate, in pixels, of the projected point.
        collection_annotations: mapping from a side name to a dictionary holding the parallel
            sequences 'ixs' and 'iys' with the annotated pixel coordinates of that side.

    Returns:
        A tuple (distance, x, y) describing the nearest annotated pixel, or None when there is no
        annotated pixel to compare against.
    """
    closest = None
    min_error = math.inf
    for side_annotations in collection_annotations.values():
        for x, y in zip(side_annotations['ixs'], side_annotations['iys']):
            error = math.sqrt((x_proj - x) ** 2 + (y_proj - y) ** 2)
            if error < min_error:  # strict comparison: the first of several equidistant pixels wins
                min_error = error
                closest = (error, x, y)
    return closest


def averageColumn(rows, col_idx):
    """Average the numeric, non-NaN cells of one column of a table.

    Args:
        rows: sequence of table rows, each one an indexable sequence of cells.
        col_idx: index of the column to average.

    Returns:
        The mean of the cells that parse as a non-NaN float, or NaN when there is none
        (instead of dividing by zero).
    """
    total = 0
    count = 0
    for row in rows:
        try:
            value = float(row[col_idx])
        except (TypeError, ValueError):  # cell is not a number (e.g. a colored label)
            continue
        if not math.isnan(value):
            total += value
            count += 1
    return total / count if count else float('nan')


def rowIndexOfLargest(rows, col_idx):
    """Return the index of the row holding the largest numeric value of a column.

    Args:
        rows: sequence of table rows, each one an indexable sequence of cells.
        col_idx: index of the column to inspect.

    Returns:
        Index of the first row with the largest non-NaN value, or None when no cell of the column
        parses as a non-NaN float.
    """
    best_idx = None
    best_value = None
    for row_idx, row in enumerate(rows):
        try:
            value = float(row[col_idx])
        except (TypeError, ValueError):
            continue
        if math.isnan(value):
            continue
        if best_value is None or value > best_value:
            best_value = value
            best_idx = row_idx
    return best_idx


if __name__ == "__main__":

    # ---------------------------------------
    # --- Read command line arguments
    # ---------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing input training dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing input testing dataset.", type=str,
                    required=True)
    ap.add_argument("-ds", "--depth_sensor", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-cs", "--camera_sensor", help="Target transformation sensor.", type=str, required=True)
    ap.add_argument("-si", "--show_images", help="If true the script shows images.", action='store_true', default=False)
    # NOTE(security): the value of -csf is passed to eval(), i.e. it executes arbitrary Python code
    # given on the command line. Only run this script with trusted arguments.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                         "returns True or False to indicate if the collection should be loaded (and used in the "
                         "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                         "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                                                                "Used or the Charuco.", action="store_true",
                    default=False)
    ap.add_argument("-pn", "--pattern_name", help="Name of the pattern for which the evaluation will be performed",
                    type=str, default='')

    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                                                              "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv",
                    type=str, required=False)

    # parse_known_args is used so that unrecognized extra arguments are ignored instead of aborting
    args = vars(ap.parse_known_args()[0])

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads the train json file containing the calibration results
    train_dataset, _ = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, _ = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections

    # Ground truth annotations, indexed below as annotations[collection_key][side]['ixs' | 'iys']
    annotations, _ = readAnnotationFile(args['test_json_file'], args['camera_sensor'])

    # --- Get mixed json (calibrated transforms from train and the rest from test)
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    source_frame = original_mixed_dataset['calibration_config']['sensors'][args['camera_sensor']]['link']
    target_frame = original_mixed_dataset['calibration_config']['sensors'][args['depth_sensor']]['link']

    # ---------------------------------------
    # --- Select the patterns to evaluate (all configured patterns unless -pn is given)
    # ---------------------------------------
    if args['pattern_name'] == '':
        patterns_to_evaluate = original_mixed_dataset['calibration_config']['calibration_patterns'].keys()
    else:
        patterns_to_evaluate = [args['pattern_name']]

    print('\nStarting evalutation...')

    for pattern_key in patterns_to_evaluate:
        mixed_dataset = deepcopy(original_mixed_dataset)

        e = {}  # dictionary with all the errors, filled in ascending collection order
        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        for collection_key, collection in od.items():
            e[collection_key] = {}  # init the dictionary of errors for this collection
            # ---------------------------------------
            # --- Range to image projection
            # ---------------------------------------
            depth2cam = getTransform(source_frame, target_frame, collection['transforms'])
            # Use the pattern currently being evaluated (previously the first configured pattern was
            # always used, so the loop variable and the -pn argument had no effect).
            # NOTE(review): assumes the 4th argument of depthToImage selects the pattern - confirm.
            pts_in_image = depthToImage(mixed_dataset, collection_key, args['test_json_file'], pattern_key,
                                        args['depth_sensor'], args['camera_sensor'], depth2cam)

            # ---------------------------------------
            # --- Get evaluation data for current collection
            # ---------------------------------------
            # os.path.join also works when the json path has no directory component
            filename = os.path.join(os.path.dirname(args['test_json_file']),
                                    collection['data'][args['camera_sensor']]['data_file'])
            image = cv2.imread(filename)
            if image is None:  # cv2.imread does not raise, it returns None when the file cannot be read
                raise FileNotFoundError('Could not read image ' + filename)
            image_height, image_width = image.shape[:2]

            # ---------------------------------------
            # --- Evaluation metrics - reprojection error
            # ---------------------------------------
            # -- For each reprojected limit point, find the closest ground truth point and compute the distance to it
            x_errors = []
            y_errors = []
            errors = []
            matches = []  # (x_proj, y_proj, x_min, y_min) of every matched point, kept for drawing
            for idx in range(pts_in_image.shape[1]):
                x_proj = pts_in_image[0, idx]
                y_proj = pts_in_image[1, idx]

                # Do not consider points that are re-projected outside of the image (NaN is rejected too)
                if not (0 <= x_proj < image_width and 0 <= y_proj < image_height):
                    continue

                closest = findClosestAnnotation(x_proj, y_proj, annotations[collection_key])
                if closest is None:  # no annotation to compare against
                    continue

                # Explicit None test above: a coordinate or an error equal to 0 is a valid match
                min_error, x_min, y_min = closest
                x_errors.append(abs(x_proj - x_min))
                y_errors.append(abs(y_proj - y_min))
                errors.append(min_error)
                matches.append((x_proj, y_proj, x_min, y_min))

            if not errors:
                print('No Depth point mapped into the image for collection ' + str(collection_key))
                e[collection_key]['x'] = float("nan")
                e[collection_key]['y'] = float("nan")
                e[collection_key]['rms'] = float("nan")
                continue

            e[collection_key]['x'] = np.average(x_errors)
            e[collection_key]['y'] = np.average(y_errors)
            e[collection_key]['rms'] = rootMeanSquare(errors)

            if args['show_images']:
                # Yellow segments link each projected point to its closest annotation
                for x_proj, y_proj, x_min, y_min in matches:
                    cv2.line(image, (int(round(x_proj)), int(round(y_proj))),
                             (int(round(x_min)), int(round(y_min))), (0, 255, 255), 1)

                # Ground truth annotations in green (a zero-length line paints a single pixel)
                for side_annotations in annotations[collection_key].values():
                    for x, y in zip(side_annotations['ixs'], side_annotations['iys']):
                        pixel = (int(round(x)), int(round(y)))
                        cv2.line(image, pixel, pixel, (0, 255, 0), 1)

                # Projected depth points in red, drawn last so they stay visible
                for x_proj, y_proj, _, _ in matches:
                    pixel = (int(round(x_proj)), int(round(y_proj)))
                    cv2.line(image, pixel, pixel, (0, 0, 255), 1)

                print('Errors collection ' + collection_key + '\n' + str(e[collection_key]))
                window_name = 'Collection ' + collection_key
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                cv2.imshow(window_name, image)
                key = cv2.waitKey(0)
                cv2.destroyWindow(window_name)
                if key == ord('q') or key == ord('c'):  # stop showing images for the remaining collections
                    args['show_images'] = False

        # -------------------------------------------------------------
        # Print output table
        # -------------------------------------------------------------
        table_header = ['Collection #', 'RMS (pix)', 'X err (pix)', 'Y err (pix)']
        table = PrettyTable(table_header)
        # Uncolored copy of the table: the color escape codes would end up as garbage in the csv file
        table_to_save = PrettyTable(table_header)

        # Iterate over the computed errors (already in ascending collection order), so that every
        # row is guaranteed to have an entry in e
        for collection_key, collection_errors in e.items():
            row = [collection_key,
                   '%.4f' % collection_errors['rms'],
                   '%.4f' % collection_errors['x'],
                   '%.4f' % collection_errors['y']]

            table.add_row(row)
            table_to_save.add_row(row)

        # Compute averages and add a bottom row
        bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Fore.BLACK + Style.NORMAL]
        bottom_row_save = ['Averages']
        for col_idx in range(1, len(table_header)):
            # NaN (printed as 'nan') when no collection produced a valid value for this column
            value = '%.4f' % averageColumn(table.rows, col_idx)
            bottom_row.append(Fore.BLUE + value + Fore.BLACK)
            bottom_row_save.append(value)

        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        # Put larger errors in red per column (per sensor)
        for col_idx in range(1, len(table_header)):
            max_row_idx = rowIndexOfLargest(table.rows[:-1], col_idx)  # ignore bottom row
            if max_row_idx is None:  # nothing numeric to highlight in this column
                continue

            # set the max column value to red
            table.rows[max_row_idx][col_idx] = Fore.RED + table.rows[max_row_idx][col_idx] + Style.RESET_ALL

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["depth_sensor"]}_to_{args["camera_sensor"]}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()