#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

import json
import os
import sys

import argparse
import numpy as np
import cv2
from prettytable import PrettyTable
from collections import OrderedDict
from scipy.spatial import distance
from colorama import Style, Fore
from copy import deepcopy

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import getMixedDataset, loadResultsJSON, filterCollectionsFromDataset, readAnnotationFile
from atom_core.utilities import saveFileResults
from atom_core.vision import depthInImage, depthToImage

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------

def annotationsToPoints(collection_annotations):
    """Stack the annotated border pixels of one collection into a 2xN array.

    Args:
        collection_annotations: mapping with the keys 'top', 'right', 'bottom' and 'left', each one holding
            the parallel coordinate lists 'ixs' and 'iys'.

    Returns:
        Float array of shape (2, N). Row 0 holds the x coordinates and row 1 the y coordinates, in the order
        top, right, bottom, left. N is 0 when there are no annotated points.
    """
    xs, ys = [], []
    for side in ('top', 'right', 'bottom', 'left'):
        # zip() pairs the coordinates, so a trailing unpaired value is ignored
        for x, y in zip(collection_annotations[side]['ixs'], collection_annotations[side]['iys']):
            xs.append(x)
            ys.append(y)
    return np.array([xs, ys], dtype=float)


def nearestPointErrors(source_pts, target_pts, width, height, border_tolerance):
    """Match every source point lying inside the image border to its nearest target point.

    Args:
        source_pts: array with one point per column; rows 0 and 1 are the x and y coordinates.
        target_pts: array with one point per column; rows 0 and 1 are the x and y coordinates.
        width: image width used to build the border.
        height: image height used to build the border.
        border_tolerance: fraction of the image size used as border. Source points with
            x < width * border_tolerance, x > width * (1 - border_tolerance) (same for y and height)
            are discarded.

    Returns:
        Tuple (delta_pts, delta_xy, kept_source, nearest_target):
            delta_pts: (M,) euclidean distance from each kept source point to its nearest target point.
            delta_xy: (M, 2) absolute per-axis differences between each kept source point and its match.
            kept_source: (M, 2) source points that were inside the border.
            nearest_target: (M, 2) target point matched to each kept source point.
        All four are empty (M = 0) when no source point is inside the border or there are no target points.
    """
    no_points = (np.empty(0), np.empty((0, 2)), np.empty((0, 2)), np.empty((0, 2)))

    source_arr = np.asarray(source_pts, dtype=float)
    target_arr = np.asarray(target_pts, dtype=float)
    if source_arr.size == 0 or target_arr.size == 0:
        return no_points

    source_xy = source_arr[0:2, :].transpose()  # (num_source, 2)
    target_xy = target_arr[0:2, :].transpose()  # (num_target, 2)

    # Points near or beyond the image limits do not count for the errors
    x, y = source_xy[:, 0], source_xy[:, 1]
    outside = (x > width * (1 - border_tolerance)) | (x < width * border_tolerance) | \
              (y > height * (1 - border_tolerance)) | (y < height * border_tolerance)
    kept_source = source_xy[~outside]
    if kept_source.shape[0] == 0:
        return no_points

    # A single distance matrix replaces the per point (and repeated) cdist calls.
    # argmin returns the first minimum, i.e. the first nearest target point in case of ties.
    pairwise = distance.cdist(kept_source, target_xy, 'euclidean')
    nearest_idx = np.argmin(pairwise, axis=1)
    delta_pts = pairwise[np.arange(nearest_idx.shape[0]), nearest_idx]
    nearest_target = target_xy[nearest_idx]
    delta_xy = np.abs(kept_source - nearest_target)

    return delta_pts, delta_xy, kept_source, nearest_target


def computeErrorMetrics(delta_pts, delta_xy):
    """Compute the per collection error metrics from the matched point errors.

    Args:
        delta_pts: sequence of M euclidean distances (pixels).
        delta_xy: sequence of M (x, y) absolute differences (pixels).

    Returns:
        Dictionary with the keys 'rms', 'avg_error_x', 'avg_error_y', 'stdev_x' and 'stdev_y'.
        Values are computed in float32.

    Raises:
        ValueError: if there are no points, since the metrics are undefined in that case.
    """
    delta_pts = np.asarray(delta_pts, dtype=np.float32)
    delta_xy = np.asarray(delta_xy, dtype=np.float32)
    total_pts = len(delta_pts)
    if total_pts == 0:
        raise ValueError('Cannot compute error metrics without points.')

    stdev_xy = np.std(delta_xy, axis=0)
    return {'rms': np.sqrt((delta_pts ** 2).mean()),
            'avg_error_x': np.sum(np.abs(delta_xy[:, 0])) / total_pts,
            'avg_error_y': np.sum(np.abs(delta_xy[:, 1])) / total_pts,
            'stdev_x': stdev_xy[0],
            'stdev_y': stdev_xy[1]}


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing input training dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing input testing dataset.", type=str,
                    required=True)
    ap.add_argument("-ss", "--sensor_source", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-st", "--sensor_target", help="Target transformation sensor.", type=str, required=True)
    ap.add_argument("-si", "--show_images", help="If true the script shows images.", action='store_true', default=False)
    ap.add_argument("-bt", "--border_tolerance", help="Define the percentage of pixels to use to create a border. Lidar points outside that border will not count for the error calculations",
                    type=float, default=0.025)
    # SECURITY NOTE: the argument below is passed to eval(). This is only acceptable because the string comes
    # from the command line of a trusted user; never feed it with external / untrusted input.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                    "returns True or False to indicate if the collection should be loaded (and used in the "
                    "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                    "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-ua", "--use_annotations", help="If true the script uses depth annotations file.", action='store_true', default=False)
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                    "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument("-pn", "--pattern_name", help="Name of the pattern for which the evaluation will be performed", type=str, default='')
    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                    "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv", type=str, required=False)

    # - Save args
    args = vars(ap.parse_known_args()[0])
    depth_sensor_target = args['sensor_target']
    depth_sensor_source = args['sensor_source']
    show_images = args['show_images']
    border_tolerance = args['border_tolerance']

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads the train json file containing the calibration results
    train_dataset, train_json_file = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections

    # ---------------------------------------
    # --- Get mixed json (calibrated transforms from train and the rest from test)
    # ---------------------------------------
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    source_frame = original_mixed_dataset['calibration_config']['sensors'][depth_sensor_source]['link']
    target_frame = original_mixed_dataset['calibration_config']['sensors'][depth_sensor_target]['link']

    if args['use_annotations']:
        annotations, _ = readAnnotationFile(args['test_json_file'], args['sensor_target'])

    # ---------------------------------------
    # --- Evaluation
    # ---------------------------------------
    print(Fore.BLUE + '\nStarting evalutation...' + Style.RESET_ALL)

    # Patterns to evaluate
    if args['pattern_name'] == '':
        patterns_to_evaluate = original_mixed_dataset['calibration_config']['calibration_patterns'].keys()
    else:
        patterns_to_evaluate = [args['pattern_name']]

    metric_keys = ['rms', 'avg_error_x', 'avg_error_y', 'stdev_x', 'stdev_y']  # same order as the table columns

    for pattern_key in patterns_to_evaluate:
        # Each pattern is evaluated on its own copy of the mixed dataset
        mixed_dataset = deepcopy(original_mixed_dataset)
        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        collections_to_skip = []
        errors = {}  # dictionary with all the errors, indexed by collection key

        for collection_key, collection in od.items():
            # ---------------------------------------
            # --- Depth to Depth projection
            # ---------------------------------------
            depth_source_2_depth_target = getTransform(target_frame, source_frame,
                                                       mixed_dataset['collections'][collection_key]['transforms'])

            # ---------------------------------------
            # --- Get evaluation data for current collection
            # ---------------------------------------
            target_depth_pts_in_img_target = depthInImage(mixed_dataset, collection_key, pattern_key, depth_sensor_target)
            source_depth_pts_in_img_target = depthToImage(
                mixed_dataset, collection_key, test_json_file, pattern_key, depth_sensor_source, depth_sensor_target, depth_source_2_depth_target)

            if args['use_annotations']:
                # The annotated border pixels replace the detected ones as reference
                target_depth_pts_in_img_target = annotationsToPoints(annotations[collection_key])

            # Without reference points there is nothing to compare against (with or without annotations)
            if np.size(target_depth_pts_in_img_target) == 0:
                print('No Depth point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.append(collection_key)
                continue

            w, h = collection['data'][depth_sensor_target]['width'], collection['data'][depth_sensor_target]['height']
            delta_pts, delta_xy, matched_source, matched_target = nearestPointErrors(
                source_depth_pts_in_img_target, target_depth_pts_in_img_target, w, h, border_tolerance)

            if len(delta_pts) == 0:
                print('No target depth point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.append(collection_key)
                continue

            # ---------------------------------------
            # --- Compute error metrics
            # ---------------------------------------
            errors[collection_key] = computeErrorMetrics(delta_pts, delta_xy)

            # ---------------------------------------
            # --- Drawing ...
            # ---------------------------------------
            if show_images:
                # The image is only needed for visualization, so it is only loaded here
                filename = os.path.join(os.path.dirname(test_json_file),
                                        collection['data'][depth_sensor_target]['data_file'])
                image = cv2.imread(filename)

                # Yellow segment between each evaluated source point and its nearest target point
                for source_pt, target_pt in zip(matched_source, matched_target):
                    image = cv2.line(image, (int(source_pt[0]), int(source_pt[1])),
                                     (int(target_pt[0]), int(target_pt[1])), (0, 255, 255), 3)

                for idx in range(0, source_depth_pts_in_img_target.shape[1]):
                    image = cv2.circle(image, (int(source_depth_pts_in_img_target[0, idx]), int(
                        source_depth_pts_in_img_target[1, idx])), 5, (255, 0, 0), -1)
                for idx in range(0, target_depth_pts_in_img_target.shape[1]):
                    image = cv2.circle(image, (int(target_depth_pts_in_img_target[0, idx]), int(
                        target_depth_pts_in_img_target[1, idx])), 5, (0, 0, 255), -1)
                cv2.imshow("Depth to Depth reprojection - collection " + str(collection_key), image)
                cv2.waitKey()

        # -------------------------------------------------------------
        # Print output table
        # -------------------------------------------------------------
        evaluated_keys = [collection_key for collection_key in od if collection_key not in collections_to_skip]
        if not evaluated_keys:
            # Averages are undefined without rows (this used to end in a division by zero)
            print(Fore.YELLOW + 'No collection could be evaluated for pattern ' + str(pattern_key) +
                  '. Skipping results table.' + Style.RESET_ALL)
            continue

        table_header = ['Collection #', 'RMS (pix)', 'X err (pix)', 'Y err (pix)', 'X StDev (pix)', 'Y StDev (pix)']
        table = PrettyTable(table_header)
        # Table to save. It has no colors, because the csv output would store them as random characters
        table_to_save = PrettyTable(table_header)

        data_rows = [[collection_key] + ['%.4f' % errors[collection_key][metric_key] for metric_key in metric_keys]
                     for collection_key in evaluated_keys]
        colored_rows = [list(row) for row in data_rows]

        # Averages are computed from the (rounded) values shown in the table
        bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL]
        bottom_row_save = ['Averages']
        for col_idx in range(1, len(table_header)):
            column_values = [float(row[col_idx]) for row in data_rows]
            value = '%.4f' % (sum(column_values) / len(column_values))
            bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
            bottom_row_save.append(value)

            # Put the largest error of the column in red (first row is used when no value is above zero)
            largest_value = 0
            largest_row_idx = 0
            for row_idx, value_in_row in enumerate(column_values):
                if value_in_row > largest_value:
                    largest_value = value_in_row
                    largest_row_idx = row_idx
            colored_rows[largest_row_idx][col_idx] = Fore.RED + data_rows[largest_row_idx][col_idx] + Style.RESET_ALL

        for colored_row, plain_row in zip(colored_rows, data_rows):
            table.add_row(colored_row)
            table_to_save.add_row(plain_row)
        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["sensor_source"]}_to_{args["sensor_target"]}_{pattern_key}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                # NOTE(review): when several patterns are evaluated every pattern writes to this same file,
                # so only the results of the last pattern are kept.
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()