#!/usr/bin/env python3

"""
Reads the calibration results from a json file and computes the evaluation metrics
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------

# Standard imports
import json
import os
import argparse
import sys
from copy import deepcopy
from collections import OrderedDict

import numpy as np
from prettytable import PrettyTable
from scipy.spatial import distance
from colorama import Style, Fore

# ROS imports
from rospy_message_converter import message_converter
from atom_core.ros_numpy import numpify

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import getMixedDataset, getPointCloudMessageFromDictionary, read_pcd, loadResultsJSON, filterCollectionsFromDataset
from atom_core.utilities import saveFileResults

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def rangeToImage(collection, json_file, pattern_key=None, ss=None):
    """
    Loads the point cloud of a sensor from disk and returns its labelled points in homogeneous coordinates.

    The point cloud file referenced by collection['data'][ss]['data_file'] is read (relative to the folder of
    json_file), and the message content is merged into collection['data'][ss] (the collection is modified in
    place). Only the points whose indices are listed in the 'idxs' field of the labels are returned.

    Supported call forms:
        rangeToImage(collection, json_file, pattern_key, ss)  -> labels read from
                                                                 collection['labels'][pattern_key][ss]
        rangeToImage(collection, json_file, ss)               -> legacy form, labels read from
                                                                 collection['labels'][ss]

    :param collection: dictionary of one collection of the dataset.
    :param json_file: path to the dataset json file; the point cloud files are resolved relative to its folder.
    :param pattern_key: name of the calibration pattern whose labels are used (sensor name in the legacy
                        three argument form).
    :param ss: name of the sensor.
    :return: numpy array of shape (4, n) where the rows are x, y, z and 1.
    :raises TypeError: if no sensor name is given.
    """
    if ss is None:
        # Legacy three argument call: the third positional argument is the sensor name and the labels are
        # indexed by sensor only.
        ss, pattern_key = pattern_key, None
    if ss is None:
        raise TypeError('rangeToImage() requires the name of the sensor (ss).')

    # os.path.join keeps the path relative when json_file has no folder component, whereas
    # dirname + '/' + file would produce a path anchored at the filesystem root.
    filename = os.path.join(os.path.dirname(json_file), collection['data'][ss]['data_file'])
    msg = read_pcd(filename)
    collection['data'][ss].update(message_converter.convert_ros_message_to_dictionary(msg))

    cloud_msg = getPointCloudMessageFromDictionary(collection['data'][ss])

    if pattern_key is None:
        sensor_labels = collection['labels'][ss]
    else:
        sensor_labels = collection['labels'][pattern_key][ss]
    idxs = sensor_labels['idxs']

    # Keep only the labelled points and stack them as homogeneous column vectors
    pc = numpify(cloud_msg)[idxs]
    points_in_vel = np.zeros((4, pc.shape[0]))
    points_in_vel[0, :] = pc['x']
    points_in_vel[1, :] = pc['y']
    points_in_vel[2, :] = pc['z']
    points_in_vel[3, :] = 1

    return points_in_vel


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------

if __name__ == "__main__":
    # ---------------------------------------
    # --- Command line arguments
    # ---------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing input training dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing input testing dataset.", type=str,
                    required=True)
    ap.add_argument("-ss", "--sensor_source", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-st", "--sensor_target", help="Target transformation sensor.", type=str, required=True)
    # NOTE(security): the value of this argument is passed to eval, i.e. arbitrary python code given on the
    # command line is executed. Only use it with trusted input.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                    "returns True or False to indicate if the collection should be loaded (and used in the "
                    "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                    "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                            "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument("-pn", "--pattern_name", help="Name of the pattern for which the evaluation will be performed", type=str, default='')

    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                   "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv", type=str, required=False)

    # - Save args (unknown arguments are ignored)
    args = vars(ap.parse_known_args()[0])
    sensor_source = args['sensor_source']
    sensor_target = args['sensor_target']

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from file
    # ---------------------------------------
    # Loads a json file containing the calibration
    train_dataset, train_json_file = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections

    # --- Get mixed json (calibrated transforms from train and the rest from test)
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    source_frame = original_mixed_dataset['calibration_config']['sensors'][sensor_source]['link']
    target_frame = original_mixed_dataset['calibration_config']['sensors'][sensor_target]['link']

    # Patterns to evaluate: all patterns in the calibration configuration, unless one was requested
    if args['pattern_name'] == '':
        patterns_to_evaluate = original_mixed_dataset['calibration_config']['calibration_patterns'].keys()
    else:
        patterns_to_evaluate = [args['pattern_name']]

    print(Fore.BLUE + '\nStarting evalutation...' + Style.RESET_ALL)

    for pattern_key in patterns_to_evaluate:
        # Each pattern is evaluated on its own copy, because collections are deleted from the dataset below
        mixed_dataset = deepcopy(original_mixed_dataset)

        # Deleting collections where the pattern is not found by the source or the target sensor.
        # Only the two evaluated sensors are inspected, so missing labels of unrelated sensors are irrelevant.
        collections_to_delete = []
        for collection_key, collection in mixed_dataset['collections'].items():
            for sensor_key in mixed_dataset['sensors']:
                if sensor_key not in (sensor_source, sensor_target):
                    continue

                sensor_labels = collection['labels'][pattern_key][sensor_key]
                if not sensor_labels['detected'] or len(sensor_labels['idxs']) == 0:
                    print(
                        Fore.RED + "Removing collection " + collection_key + ' -> pattern was not found in sensor ' +
                        sensor_key + ' (must be found in all sensors).' + Style.RESET_ALL)

                    collections_to_delete.append(collection_key)
                    break

        for collection_key in collections_to_delete:
            del mixed_dataset['collections'][collection_key]

        # Collections are processed in numeric order of their keys
        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        e = {}  # dictionary with all the errors
        collections_to_skip = []
        for collection_key, collection in od.items():
            e[collection_key] = {}  # init the dictionary of errors for this collection

            # ---------------------------------------
            # --- Transform the source points to the frame of the target sensor
            # ---------------------------------------
            lidar_target_to_lidar_source = getTransform(target_frame, source_frame,
                                                        mixed_dataset['collections'][collection_key]['transforms'])
            lidar_pts_source = rangeToImage(collection, test_json_file, pattern_key, sensor_source)
            lidar_pts_target = rangeToImage(collection, test_json_file, pattern_key, sensor_target)
            lidar_pts_source_in_lidar_target = np.dot(lidar_target_to_lidar_source, lidar_pts_source)

            # Nothing can be compared if one of the clouds has no points
            if lidar_pts_target.shape[1] == 0 or lidar_pts_source.shape[1] == 0:
                print('No LiDAR point mapped into the image for collection ' + str(collection_key))
                collections_to_skip.append(collection_key)
                continue

            # ---------------------------------------
            # --- Get evaluation data for current collection
            # ---------------------------------------
            # (n, 3) array with the xyz of the source points, computed once for all target points
            source_xyz_in_target = lidar_pts_source_in_lidar_target.transpose()[:, :3]

            delta_pts = []  # euclidean distance from each target point to its nearest source point
            distances = []  # absolute per axis (x, y, z) difference to that nearest source point
            for idx in range(lidar_pts_target.shape[1]):
                lidar_pt = np.reshape(lidar_pts_target[0:3, idx], (1, 3))
                dist = distance.cdist(lidar_pt, source_xyz_in_target, 'euclidean')

                # dist has a single row, therefore the flat argmin is the index of the nearest source point
                nearest_idx = np.argmin(dist)
                delta_pts.append(dist[0, nearest_idx])
                distances.append(np.abs(lidar_pt[0] - source_xyz_in_target[nearest_idx]))

            # ---------------------------------------
            # --- Compute error metrics
            # ---------------------------------------
            total_pts = len(delta_pts)
            delta_pts = np.array(delta_pts, np.float32)
            rms = np.sqrt((delta_pts ** 2).mean())

            delta_xy = np.array(distances, np.float32)  # shape (total_pts, 3)
            avg_error_x = np.sum(np.abs(delta_xy[:, 0])) / total_pts
            avg_error_y = np.sum(np.abs(delta_xy[:, 1])) / total_pts
            stdev_xy = np.std(delta_xy, axis=0)

            # Recording the errors
            e[collection_key]['rms'] = rms
            e[collection_key]['avg_error_x'] = avg_error_x
            e[collection_key]['avg_error_y'] = avg_error_y
            e[collection_key]['stdev_x'] = stdev_xy[0]
            e[collection_key]['stdev_y'] = stdev_xy[1]

        # -------------------------------------------------------------
        # Print output table
        # -------------------------------------------------------------
        table_header = ['Collection #', 'RMS (m)', 'X err (m)', 'Y err (m)', 'X StDev (m)', 'Y StDev (m)']
        table = PrettyTable(table_header)
        # Second table without color codes, since these would be written to the csv as garbage characters
        table_to_save = PrettyTable(table_header)

        od = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))
        for collection_key, collection in od.items():
            if collection_key in collections_to_skip:
                continue
            row = [collection_key,
                   '%.4f' % e[collection_key]['rms'],
                   '%.4f' % e[collection_key]['avg_error_x'],
                   '%.4f' % e[collection_key]['avg_error_y'],
                   '%.4f' % e[collection_key]['stdev_x'],
                   '%.4f' % e[collection_key]['stdev_y']]

            table.add_row(row)
            table_to_save.add_row(row)

        has_collection_rows = len(table.rows) > 0

        # Compute averages and add a bottom row
        bottom_row = []
        bottom_row_save = []
        for col_idx, _ in enumerate(table_header):
            if col_idx == 0:
                bottom_row.append(Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL)
                bottom_row_save.append('Averages')
                continue

            total = 0
            count = 0
            for row in table.rows:
                try:
                    total += float(row[col_idx])
                    count += 1
                except (TypeError, ValueError):  # cell does not contain a number
                    pass

            # Without any numeric cell there is no average to report (avoids a division by zero)
            value = '%.4f' % (total / count) if count > 0 else '---'
            bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
            bottom_row_save.append(value)

        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        # Put larger errors in red per column. Skipped when there are no collection rows, otherwise the
        # averages row would be the one highlighted.
        if has_collection_rows:
            for col_idx, _ in enumerate(table_header):
                if col_idx == 0:  # nothing to do
                    continue

                max_value = 0
                max_row_idx = 0
                for row_idx, row in enumerate(table.rows[: -1]):  # ignore bottom row
                    try:
                        value = float(row[col_idx])
                    except (TypeError, ValueError):
                        continue

                    if value > max_value:
                        max_value = value
                        max_row_idx = row_idx

                # set the max column value to red
                table.rows[max_row_idx][col_idx] = Fore.RED + table.rows[max_row_idx][col_idx] + Style.RESET_ALL

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        # NOTE(review): this runs once per evaluated pattern with the same file name, so with several patterns
        # the explicitly named file is rewritten on every iteration and only the last pattern remains.
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["sensor_source"]}_to_{args["sensor_target"]}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    sys.exit()
