#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

# Standard imports
import json
import os
import argparse
import sys
from collections import OrderedDict

import numpy as np
import cv2
from scipy.spatial import distance
from colorama import Fore, Style
from prettytable import PrettyTable
from copy import deepcopy

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import getMixedDataset, loadResultsJSON, filterCollectionsFromDataset, readAnnotationFile
from atom_core.utilities import saveFileResults
from atom_core.vision import depthInImage, rangeToImage

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------

# --- Helper functions used by the script below

def annotationsToPoints(annotation):
    """Stack the annotated border pixels of one collection into a 2xN array.

    Args:
        annotation: Mapping with the keys 'top', 'right', 'bottom' and 'left', each one holding
            the lists 'ixs' and 'iys'.

    Returns:
        Float array of shape (2, N). Row 0 holds the 'ixs' values and row 1 the 'iys' values,
        concatenated in the order top, right, bottom, left. The shape is (2, 0) when there are
        no annotated pixels.
    """
    xs, ys = [], []
    for side in ('top', 'right', 'bottom', 'left'):
        # zip() only yields complete (x, y) pairs, so lists of different length never leave
        # uninitialised columns in the output.
        for x, y in zip(annotation[side]['ixs'], annotation[side]['iys']):
            xs.append(x)
            ys.append(y)
    return np.array([xs, ys], dtype=float)


def isInsideBorder(x, y, width, height, border_tolerance):
    """Tell whether pixel (x, y) lies inside the image once a border is discarded.

    Args:
        x, y: Pixel coordinates.
        width, height: Image size in pixels.
        border_tolerance: Fraction of the width / height discarded along each image edge.

    Returns:
        False when the pixel falls in the discarded border or outside the image, True otherwise.
    """
    outside = x > width * (1 - border_tolerance) or x < width * border_tolerance or \
        y > height * (1 - border_tolerance) or y < height * border_tolerance
    return not outside


def nearestPoint(query_pt, candidate_pts):
    """Find the candidate closest (euclidean distance) to a query point.

    Args:
        query_pt: Array of shape (1, 2).
        candidate_pts: Array of shape (M, 2), with M > 0.

    Returns:
        Tuple (index, distance) of the nearest candidate. Ties resolve to the lowest index.
    """
    dists = distance.cdist(query_pt, candidate_pts, 'euclidean')[0]
    nearest_idx = int(np.argmin(dists))
    return nearest_idx, dists[nearest_idx]


def columnAverage(rows, col_idx):
    """Average the numeric entries of one table column.

    Args:
        rows: List of table rows (each row is a list of cell values).
        col_idx: Index of the column to average.

    Returns:
        Mean of the cells that can be converted to float, or NaN when there is none (e.g. all
        collections were skipped), instead of failing with a division by zero.
    """
    values = []
    for row in rows:
        try:
            values.append(float(row[col_idx]))
        except (TypeError, ValueError):  # cell is not a number (e.g. colored text)
            continue

    if not values:
        return float('nan')
    return sum(values) / len(values)


def maxValueRowIndex(rows, col_idx):
    """Locate the row holding the largest numeric value of one table column.

    Args:
        rows: List of table rows (each row is a list of cell values).
        col_idx: Index of the column to inspect.

    Returns:
        Index of the first row with the largest value, or None when no cell of the column can
        be converted to a (non NaN) float.
    """
    max_value = float('-inf')
    max_row_idx = None
    for row_idx, row in enumerate(rows):
        try:
            value = float(row[col_idx])
        except (TypeError, ValueError):
            continue

        if value > max_value:  # always False for NaN, so NaN cells are never selected
            max_value = value
            max_row_idx = row_idx

    return max_row_idx


# --- Script entry point

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing input training dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing input testing dataset.", type=str,
                    required=True)
    ap.add_argument("-rs", "--range_sensor", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-ds", "--depth_sensor", help="Target transformation sensor.", type=str, required=True)
    ap.add_argument("-si", "--show_images", help="If true the script shows images.", action='store_true', default=False)
    ap.add_argument("-bt", "--border_tolerance", help="Define the percentage of pixels to use to create a border. Lidar points outside that border will not count for the error calculations",
                                                 type=float, default=0.025)
    ap.add_argument("-ua", "--use_annotations", help="If true the script uses depth annotations file.", action='store_true', default=False)
    # SECURITY: the -csf string is passed to eval(), i.e. it runs arbitrary python code. Only use
    # this script with command lines from a trusted source.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                    "returns True or False to indicate if the collection should be loaded (and used in the "
                    "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                    "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                            "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument("-pn", "--pattern_name", help="Name of the pattern for which the evaluation will be performed", type=str, default='')
    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                   "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv", type=str, required=False)

    # - Save args (unknown command line arguments are ignored)
    args = vars(ap.parse_known_args()[0])
    range_sensor = args['range_sensor']
    depth_sensor = args['depth_sensor']
    show_images = args['show_images']
    border_tolerance = args['border_tolerance']

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads the train json file containing the calibration results
    train_dataset, train_json_file = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections

    # ---------------------------------------
    # --- Get mixed json (calibrated transforms from train and the rest from test)
    # ---------------------------------------
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    source_frame = original_mixed_dataset['calibration_config']['sensors'][range_sensor]['link']
    target_frame = original_mixed_dataset['calibration_config']['sensors'][depth_sensor]['link']

    # Patterns to evaluate: the one given with -pn, or all the patterns in the calibration config
    if args['pattern_name'] == '':
        patterns_to_evaluate = original_mixed_dataset['calibration_config']['calibration_patterns'].keys()
    else:
        patterns_to_evaluate = [args['pattern_name']]

    # ---------------------------------------
    # --- INITIALIZATION Read evaluation data from file ---> if desired <---
    # ---------------------------------------
    print(Fore.BLUE + '\nStarting evalutation...' + Style.RESET_ALL)

    if args['use_annotations']:
        annotations, _ = readAnnotationFile(args['test_json_file'], args['depth_sensor'])

    for pattern_key in patterns_to_evaluate:
        # Each pattern is evaluated on its own copy of the mixed dataset
        mixed_dataset = deepcopy(original_mixed_dataset)

        # Collections sorted by their (integer) key
        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        e = {}  # dictionary with all the errors
        collections_to_skip = set()  # collections without points to compare, left out of the table
        for collection_key, collection in od.items():
            e[collection_key] = {}  # init the dictionary of errors for this collection
            # ---------------------------------------
            # --- Range to depth projection
            # ---------------------------------------
            lidar2depth = getTransform(target_frame, source_frame, mixed_dataset['collections'][collection_key]['transforms'])
            lidar_pts_in_img = rangeToImage(mixed_dataset, collection, test_json_file, pattern_key, range_sensor, depth_sensor, lidar2depth)

            # ---------------------------------------
            # --- Get evaluation data for current collection
            # ---------------------------------------
            filename = os.path.dirname(test_json_file) + '/' + collection['data'][depth_sensor]['data_file']
            depth_pts_in_depth_img = depthInImage(mixed_dataset, collection_key, pattern_key, depth_sensor)

            if args['use_annotations']:
                # The annotated border pixels replace the points computed from the dataset
                depth_pts_in_depth_img = annotationsToPoints(annotations[collection_key])

            # Without depth points there is nothing to measure the lidar points against
            if depth_pts_in_depth_img.size == 0:
                print('No Depth point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.add(collection_key)
                continue

            # The image is only needed for drawing
            image = cv2.imread(filename) if show_images else None

            # Loop invariants: depth points as rows, and their (x, y) columns for the distance search
            depth_pts_as_rows = depth_pts_in_depth_img.transpose()
            depth_pts_xy = depth_pts_as_rows[:, :2]

            delta_pts = []  # distance from each lidar point to its nearest depth point
            distances = []  # per axis absolute offsets, each one with shape (1, 2)
            w, h = collection['data'][depth_sensor]['width'], collection['data'][depth_sensor]['height']
            for idx in range(lidar_pts_in_img.shape[1]):
                x = lidar_pts_in_img[0][idx]
                y = lidar_pts_in_img[1][idx]

                # If the points are near or surpassing the image limits, do not count them for the errors
                if not isInsideBorder(x, y, w, h, border_tolerance):
                    continue

                lidar_pt = np.reshape(lidar_pts_in_img[0:2, idx], (1, 2))
                nearest_idx, nearest_dist = nearestPoint(lidar_pt, depth_pts_xy)
                min_dist_pt = depth_pts_as_rows[nearest_idx]

                delta_pts.append(nearest_dist)
                distances.append(abs(lidar_pt - min_dist_pt))

                if show_images:
                    # Line between the lidar point and its nearest depth point
                    image = cv2.line(image, (int(x), int(y)),
                                     (int(min_dist_pt[0]), int(min_dist_pt[1])), (0, 255, 255), 2)

            if len(delta_pts) == 0:
                print('No LiDAR point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.add(collection_key)
                continue

            # ---------------------------------------
            # --- Compute error metrics
            # ---------------------------------------
            total_pts = len(delta_pts)
            delta_pts = np.array(delta_pts, np.float32)
            rms = np.sqrt((delta_pts ** 2).mean())

            delta_xy = np.array(distances, np.float32)
            delta_xy = delta_xy[:, 0]  # (N, 1, 2) -> (N, 2)
            avg_error_x = np.sum(np.abs(delta_xy[:, 0])) / total_pts
            avg_error_y = np.sum(np.abs(delta_xy[:, 1])) / total_pts
            stdev_xy = np.std(delta_xy, axis=0)

            # Recording the errors
            e[collection_key]['rms'] = rms
            e[collection_key]['avg_error_x'] = avg_error_x
            e[collection_key]['avg_error_y'] = avg_error_y
            e[collection_key]['stdev_x'] = stdev_xy[0]
            e[collection_key]['stdev_y'] = stdev_xy[1]

            # ---------------------------------------
            # --- Drawing ...
            # ---------------------------------------
            if show_images:
                for idx in range(0, lidar_pts_in_img.shape[1]):
                    image = cv2.circle(image, (int(lidar_pts_in_img[0, idx]), int(
                        lidar_pts_in_img[1, idx])), 5, (255, 0, 0), -1)
                for idx in range(0, depth_pts_in_depth_img.shape[1]):
                    image = cv2.circle(image, (int(depth_pts_in_depth_img[0, idx]), int(
                        depth_pts_in_depth_img[1, idx])), 5, (0, 0, 255), -1)
                win_name = "Lidar to Depth reprojection - collection " + str(collection_key)
                cv2.imshow(win_name, image)
                cv2.waitKey()
                cv2.destroyWindow(winname=win_name)

        # -------------------------------------------------------------
        # Print output table
        # -------------------------------------------------------------
        table_header = ['Collection #', 'RMS (pix)', 'X err (pix)', 'Y err (pix)', 'X StDev (pix)', 'Y StDev (pix)']
        table = PrettyTable(table_header)
        table_to_save = PrettyTable(table_header) # table to save. This table was created, because the original has colors and the output csv save them as random characters

        for collection_key in od:
            if collection_key in collections_to_skip:
                continue
            row = [collection_key,
                   '%.4f' % e[collection_key]['rms'],
                   '%.4f' % e[collection_key]['avg_error_x'],
                   '%.4f' % e[collection_key]['avg_error_y'],
                   '%.4f' % e[collection_key]['stdev_x'],
                   '%.4f' % e[collection_key]['stdev_y']]

            table.add_row(row)
            table_to_save.add_row(row)

        # Compute averages and add a bottom row (the averages are NaN when no collection was evaluated)
        collection_rows = table.rows
        bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL]
        bottom_row_save = ['Averages']
        for col_idx in range(1, len(table_header)):  # column 0 holds the collection key
            value = '%.4f' % columnAverage(collection_rows, col_idx)
            bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
            bottom_row_save.append(value)

        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        # Put larger errors in red per column (per sensor)
        rows_without_averages = table.rows[: -1]  # ignore bottom row
        for col_idx in range(1, len(table_header)):
            max_row_idx = maxValueRowIndex(rows_without_averages, col_idx)
            if max_row_idx is None:  # no collection row, nothing to highlight
                continue

            # set the max column value to red
            worst_row = rows_without_averages[max_row_idx]
            worst_row[col_idx] = Fore.RED + worst_row[col_idx] + Style.RESET_ALL

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        # NOTE(review): the file name does not depend on pattern_key, so when several patterns are
        # evaluated every iteration targets the same name. With -sfrn the file is reopened in 'w'
        # mode and only the last pattern remains; what saveFileResults does is not visible here.
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["range_sensor"]}_to_{args["depth_sensor"]}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()