#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

# Standard imports
import json
import math
import os
import argparse
import sys
from collections import OrderedDict

import numpy as np
import atom_core.ros_numpy
import cv2
from prettytable import PrettyTable
from colorama import Style, Fore
from copy import deepcopy

# ROS imports
from rospy_message_converter import message_converter

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import getMixedDataset, getPointCloudMessageFromDictionary, read_pcd, readAnnotationFile, loadResultsJSON, filterCollectionsFromDataset
from atom_core.utilities import rootMeanSquare, saveFileResults
from atom_core.vision import projectToCamera

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def rangeToImage(collection, json_file, rc, cs, tf, pattern_key, dataset=None):
    """
    Projects the labelled limit points of a range sensor into the image of a camera sensor.

    collection  -> collection dictionary. collection['data'][rc] is updated in place with the fields of the
                   point cloud message read from disk.
    json_file   -> path of the dataset json file; the point cloud file is looked up in the same folder.
    rc          -> key of the ranging sensor
    cs          -> key of the camera sensor
    tf          -> a transformation from the ranging to the imaging sensor. It left-multiplies the 4xN
                   homogeneous points, so it must have 4 columns.
    pattern_key -> key of the pattern whose 'idxs_limit_points' labels select the points to project
    dataset     -> dictionary containing sensors[cs]['camera_info'] ('K' and 'D'). When None, the module level
                   variable mixed_dataset is used, which is how this function obtained the intrinsics before
                   this parameter existed.

    Returns the first output of projectToCamera (the projected points in the image).
    """

    if dataset is None:
        dataset = mixed_dataset  # module level variable created by the main script

    # os.path.join (instead of dirname + '/' + file) keeps the path relative when json_file has no folder
    # component; the plain concatenation produced '/<data_file>' in that case.
    filename = os.path.join(os.path.dirname(json_file), collection['data'][rc]['data_file'])
    msg = read_pcd(filename)
    collection['data'][rc].update(message_converter.convert_ros_message_to_dictionary(msg))

    cloud_msg = getPointCloudMessageFromDictionary(collection['data'][rc])

    idxs = collection['labels'][pattern_key][rc]['idxs_limit_points']

    # Keep only the limit points and write them as homogeneous coordinates (one point per column)
    pc = atom_core.ros_numpy.numpify(cloud_msg)[idxs]
    points_in_vel = np.ones((4, pc.shape[0]))  # last row stays at 1
    points_in_vel[0, :] = pc['x']
    points_in_vel[1, :] = pc['y']
    points_in_vel[2, :] = pc['z']

    points_in_cam = np.dot(tf, points_in_vel)

    # -- Project them to the image
    w, h = collection['data'][cs]['width'], collection['data'][cs]['height']
    camera_info = dataset['sensors'][cs]['camera_info']
    # Convert the values to float explicitly. Wrapping the list in np.ndarray(buffer=...) reinterprets the raw
    # bytes, which yields garbage when the json stores integer values.
    K = np.array(camera_info['K'], dtype=float).reshape(3, 3)
    # Only the first five distortion coefficients are used, as before.
    D = np.array(camera_info['D'], dtype=float).flatten()[:5].reshape(5, 1)

    pts_in_image, _, _ = projectToCamera(K, D, w, h, points_in_cam[0:3, :])

    return pts_in_image

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


if __name__ == "__main__":

    # ---------------------------------------
    # --- Read command line arguments
    # ---------------------------------------

    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing input training dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing input testing dataset.", type=str,
                    required=True)
    ap.add_argument("-rs", "--range_sensor", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-cs", "--camera_sensor", help="Target transformation sensor.", type=str, required=True)
    ap.add_argument("-si", "--show_images", help="If true the script shows images.", action='store_true', default=False)
    # NOTE(security): the -csf string is passed to eval, i.e. it executes arbitrary python code. Only use this
    # script with command lines that come from a trusted source.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                    "returns True or False to indicate if the collection should be loaded (and used in the "
                    "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                    "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                            "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument("-pn", "--pattern_name", help="Name of the pattern for which the evaluation will be performed", type=str, default='')

    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                   "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv", type=str, required=False)

    # parse args (unknown arguments are ignored)
    args = vars(ap.parse_known_args()[0])
    camera_sensor = args['camera_sensor']
    range_sensor = args['range_sensor']
    sensors_of_interest = (camera_sensor, range_sensor)
    test_json_dir = os.path.dirname(args['test_json_file'])

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads a json file containing the calibration
    train_dataset, train_json_file = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])
    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections
    annotations, annotations_file = readAnnotationFile(args['test_json_file'], camera_sensor)

    # --- Get mixed json (calibrated transforms from train and the rest from test)
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    print(Fore.BLUE + '\nStarting evalutation...' + Style.RESET_ALL)

    # Patterns to evaluate: all the configured ones unless a single pattern was requested
    if args['pattern_name'] == '':
        patterns_to_evaluate = list(original_mixed_dataset['calibration_config']['calibration_patterns'].keys())
    else:
        patterns_to_evaluate = [args['pattern_name']]

    from_frame = original_mixed_dataset['calibration_config']['sensors'][camera_sensor]['link']
    to_frame = original_mixed_dataset['calibration_config']['sensors'][range_sensor]['link']

    for pattern_key in patterns_to_evaluate:
        # Each pattern works on its own copy because collections are deleted below.
        # NOTE: this name must stay a module level variable called mixed_dataset, rangeToImage reads it.
        mixed_dataset = deepcopy(original_mixed_dataset)

        # ---------------------------------------
        # --- INITIALIZATION Read evaluation data from file ---> if desired <---
        # ---------------------------------------
        # Deleting collections where the pattern is not found by both evaluated sensors:
        collections_to_delete = []
        for collection_key, collection in mixed_dataset['collections'].items():
            for sensor_key in mixed_dataset['sensors']:
                if sensor_key not in sensors_of_interest:  # other sensors do not matter for this evaluation
                    continue

                if not collection['labels'][pattern_key][sensor_key]['detected']:
                    print(
                        Fore.RED + "Removing collection " + collection_key + ' -> pattern was not found in sensor ' +
                        sensor_key + ' (must be found in all sensors).' + Style.RESET_ALL)

                    collections_to_delete.append(collection_key)
                    break

        for collection_key in collections_to_delete:
            del mixed_dataset['collections'][collection_key]

        # Collections sorted by their numeric key
        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        e = {}  # dictionary with all the errors
        collections_to_skip = []  # collections for which no error could be computed

        for collection_key, collection in od.items():
            e[collection_key] = {}  # init the dictionary of errors for this collection
            # ---------------------------------------
            # --- Range to image projection
            # ---------------------------------------
            lidar2cam = getTransform(from_frame, to_frame, collection['transforms'])

            pts_in_image = rangeToImage(collection,
                                        args['test_json_file'],
                                        range_sensor,
                                        camera_sensor,
                                        lidar2cam,
                                        pattern_key)

            # ---------------------------------------
            # --- Get evaluation data for current collection
            # ---------------------------------------
            # os.path.join keeps the path relative when the json path has no folder component
            filename = os.path.join(test_json_dir, collection['data'][camera_sensor]['data_file'])
            image = cv2.imread(filename)
            if image is None:  # cv2.imread returns None instead of raising when the file cannot be read
                print(Fore.RED + 'Could not read image ' + filename + ' for collection ' + str(collection_key) +
                      ', skipping it.' + Style.RESET_ALL)
                collections_to_skip.append(collection_key)
                continue

            # All annotated ground truth points of this collection, gathered once instead of per projected point
            gt_points = [(x, y)
                         for side in annotations[collection_key].values()
                         for x, y in zip(side['ixs'], side['iys'])]
            if not gt_points:  # nothing to compare against; the nearest point search below would find none
                print(Fore.RED + 'No annotated points for collection ' + str(collection_key) + ', skipping it.' +
                      Style.RESET_ALL)
                collections_to_skip.append(collection_key)
                continue

            if args['show_images']:  # draw all ground truth annotations
                for x, y in gt_points:
                    cv2.circle(image, (int(round(x)), int(round(y))), 1, (0, 255, 0), -1)

            # ---------------------------------------
            # --- Evaluation metrics - reprojection error
            # ---------------------------------------
            # -- For each reprojected limit point, find the closest ground truth point and compute the distance to it
            x_errors = []
            y_errors = []
            errors = []
            img_height, img_width = image.shape[:2]
            for idx in range(pts_in_image.shape[1]):
                x_proj = pts_in_image[0, idx]
                y_proj = pts_in_image[1, idx]

                # Do not consider points that are re-projected outside of the image. A coordinate equal to the
                # width / height is already outside, and NaN coordinates fail this test as well.
                if not (0 <= x_proj < img_width and 0 <= y_proj < img_height):
                    continue

                min_error = sys.float_info.max  # a very large value
                x_min = None
                y_min = None
                for x, y in gt_points:
                    error = math.sqrt((x_proj - x)**2 + (y_proj - y)**2)
                    if error < min_error:
                        min_error = error
                        x_min = x
                        y_min = y

                if x_min is None:  # no finite distance to any ground truth point
                    continue

                x_errors.append(abs(x_proj - x_min))
                y_errors.append(abs(y_proj - y_min))
                errors.append(min_error)

                if args['show_images']:
                    projected_px = (int(round(x_proj)), int(round(y_proj)))
                    cv2.circle(image, projected_px, 5, (255, 0, 0), -1)
                    # color and thickness are separate arguments
                    cv2.line(image, projected_px, (int(round(x_min)), int(round(y_min))), (0, 255, 255), 1)

            if not errors:
                print('No LiDAR point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.append(collection_key)
                continue

            e[collection_key]['x'] = np.average(x_errors)
            e[collection_key]['y'] = np.average(y_errors)
            e[collection_key]['rms'] = rootMeanSquare(errors)

            if args['show_images']:
                print('Errors collection ' + collection_key + '\n' + str(e[collection_key]))
                window_name = 'Collection ' + collection_key
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                cv2.imshow(window_name, image)
                key = cv2.waitKey(0)
                cv2.destroyWindow(window_name)
                if key == ord('q') or key == ord('c'):  # stop showing images for the remaining collections
                    args['show_images'] = False

        # -------------------------------------------------------------
        # Print output table
        # -------------------------------------------------------------
        evaluated_keys = [collection_key for collection_key in od if collection_key not in collections_to_skip]
        if not evaluated_keys:
            # Without rows the averages below would divide by zero
            print(Fore.RED + 'No collection could be evaluated for pattern ' + str(pattern_key) + '.' +
                  Style.RESET_ALL)
            continue

        table_header = ['Collection #', 'RMS (pix)', 'X err (pix)', 'Y err (pix)']
        table = PrettyTable(table_header)
        # table to save. This table was created, because the original has colors and the output csv save them as
        # random characters
        table_to_save = PrettyTable(table_header)

        # Per column list of the (rounded) values shown in the table, used for the averages and the highlighting
        columns = {col_idx: [] for col_idx in range(1, len(table_header))}
        for collection_key in evaluated_keys:
            row = [collection_key,
                   '%.4f' % e[collection_key]['rms'],
                   '%.4f' % e[collection_key]['x'],
                   '%.4f' % e[collection_key]['y']]

            table.add_row(row)
            table_to_save.add_row(row)
            for col_idx, values in columns.items():
                values.append(float(row[col_idx]))

        # Compute averages and add a bottom row
        bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL]
        bottom_row_save = ['Averages']
        for col_idx in range(1, len(table_header)):
            value = '%.4f' % (sum(columns[col_idx]) / len(columns[col_idx]))
            bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
            bottom_row_save.append(value)

        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        # Put larger errors in red per column (the first row wins in case of a tie)
        for col_idx, values in columns.items():
            max_row_idx = max(range(len(values)), key=values.__getitem__)
            table.rows[max_row_idx][col_idx] = Fore.RED + table.rows[max_row_idx][col_idx] + Style.RESET_ALL

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["range_sensor"]}_to_{args["camera_sensor"]}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                # NOTE(review): when several patterns are evaluated, every pattern rewrites this same file, so
                # only the last pattern is kept.
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()
