#!/usr/bin/env python3

"""
Reads the calibration results from two JSON files (train and test datasets) and computes the evaluation metrics.
"""

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------
# Standard imports
import argparse
import copy
import json
from collections import OrderedDict
import sys

# Third-party imports
import numpy as np
from colorama import Style, Fore
from prettytable import PrettyTable

# Atom imports
from atom_core.atom import getTransform
from atom_core.utilities import saveFileResults
from atom_core.naming import generateKey
from atom_core.dataset_io import getMixedDataset
from atom_core.geometry import matrixToRodrigues

# -------------------------------------------------------------------------------
# --- CLASSES 
# -------------------------------------------------------------------------------
class Table:
    """Render the computed errors as a PrettyTable and optionally export them.

    Two tables are kept in sync: ``table`` is the one printed to the terminal
    (it may carry a title and colour codes) and ``table_to_save`` holds the
    same rows as plain text for the CSV export.
    """

    # Keys of an error entry, listed in the same order as the value columns of
    # the header. Every row is built from this tuple so that the per-row values
    # and the averages always land under the right column.
    VALUE_KEYS = ('Trans', 'Rot', 'x', 'y', 'z', 'roll', 'pitch', 'yaw')

    def __init__(self, dict, args, compute_all_trans, target_frame, source_child) -> None:
        """Store the data needed to build the table.

        Args:
            dict: Errors keyed by row label. Each value is a mapping holding the
                keys listed in ``VALUE_KEYS``. When ``compute_all_trans`` is
                false an ``'Averages'`` entry is expected as well.
            args: Parsed command line arguments, as a dictionary.
            compute_all_trans: True when each row is a transformation, False
                when each row is a collection of a single transformation.
            target_frame: Name of the target frame (used in the title).
            source_child: Name of the source child frame (used in the title).
        """
        self.dict = dict
        self.compute_all_trans = compute_all_trans
        self.arg = args
        self.target_frame = target_frame
        self.source_child = source_child

    def _format_row(self, label, values) -> list:
        """Return ``[label, value, ...]`` with values formatted in column order."""
        return [label] + ['%.4f' % values[key] for key in self.VALUE_KEYS]

    def generateHeader(self) -> None:
        """Create both tables with the header matching the current mode."""
        first_column = 'Transformation #' if self.compute_all_trans else 'Collection #'
        self.header = [first_column, 'Trans (m)', 'Rot (rad)',
                       'X (m)', 'Y (m)', 'Z (m)',
                       'Roll (rad)', 'Pitch (rad)', 'Yaw (rad)']

        self.table = PrettyTable(self.header)
        self.table_to_save = PrettyTable(self.header)

    def generateTitle(self):
        """Title the printed table when it describes a single transformation."""
        if not self.compute_all_trans:
            self.table.title = f'{self.target_frame} to {self.source_child}'

    def adding_rows(self) -> None:
        """Add one row per transformation, or per collection when requested.

        In single-transformation mode the rows are only added when the
        ``show_all_collections`` argument is set; the ``'Averages'`` entry is
        skipped here because ``addAverages`` appends it separately.
        """
        # Use the arguments handed to the constructor rather than a module
        # global, so the class also works when imported from other code.
        if not self.compute_all_trans and not self.arg['show_all_collections']:
            return

        for row_key, values in self.dict.items():
            if not self.compute_all_trans and row_key == 'Averages':
                continue

            row = self._format_row(row_key, values)
            self.table.add_row(row)
            self.table_to_save.add_row(row)

    def table_to_csv(self) -> None:
        """Export ``table_to_save`` when the ``save_file_results`` argument is set.

        Without an explicit ``save_file_results_name`` the export is delegated
        to ``saveFileResults`` with the default file name; otherwise the CSV
        string is written to the given file name.
        """
        if not self.arg['save_file_results']:
            return

        results_name = self.arg['save_file_results_name']
        if results_name is None:
            results_name = 'ground_truth_frame_results.csv'
            saveFileResults(self.arg['train_json_file'], self.arg['test_json_file'], results_name, self.table_to_save)
        else:
            with open(results_name, 'w', newline='') as f_output:
                f_output.write(self.table_to_save.get_csv_string())

    def addAverages(self) -> None:
        """Append the ``'Averages'`` entry as the bottom row of both tables.

        The values are picked by key, in column order, instead of relying on
        the insertion order of the ``'Averages'`` mapping; otherwise they could
        end up under the wrong column.
        """
        bottom_row_save = self._format_row('Averages', self.dict['Averages'])

        # The printed row is the same data wrapped in colour codes.
        bottom_row = [Fore.BLUE + Style.BRIGHT + 'Averages' + Fore.BLACK + Style.NORMAL]
        bottom_row.extend(Fore.BLUE + value + Fore.BLACK for value in bottom_row_save[1:])

        self.table.add_row(bottom_row)
        self.table_to_save.add_row(bottom_row_save)

    def printTable(self) -> None:
        """Build the tables, print the terminal one and export the plain one."""
        self.generateHeader()
        self.generateTitle()
        self.adding_rows()
        self.table.align = 'c'
        self.table_to_save.align = 'c'
        if not self.compute_all_trans:
            self.addAverages()
        print(Style.BRIGHT + 'Errors per frame' + Style.RESET_ALL)
        print(self.table)
        self.table_to_csv()

class ErrorsDictionary:
    """Compute the errors of one frame between the mixed and the test datasets.

    For every collection present in both datasets, the transform from
    ``target_frame`` to ``source_child`` is read from each dataset and the
    difference between the two is split into a translation part and a rotation
    part (the output of ``matrixToRodrigues``). The results are written into
    the shared ``dict``: a single averaged entry keyed by ``source_frame`` in
    all-transformations mode, or one entry per collection plus an
    ``'Averages'`` entry otherwise.
    """

    def __init__(self, mixed_dataset, test_dataset, transform_key, target_frame, source_frame, source_child, dict, comput_all_trans) -> None:
        """Store the datasets, the frame names and the output dictionary.

        Args:
            mixed_dataset: Dataset with a ``'collections'`` mapping, used as the
                calibrated side of the comparison.
            test_dataset: Dataset with a ``'collections'`` mapping, used as the
                other side of the comparison.
            transform_key: Key of the evaluated transformation (stored only).
            target_frame: Frame the transforms are expressed from.
            source_frame: Label used as the output key in all-transformations mode.
            source_child: Frame the transforms are expressed to.
            dict: Output dictionary, filled in place.
            comput_all_trans: True to store one averaged entry for this frame,
                False to store one entry per collection plus the averages.
        """
        self.mixed_dataset = mixed_dataset
        self.test_dataset = test_dataset
        self.transform_key = transform_key
        self.target_frame = target_frame
        self.source_frame = source_frame
        self.source_child = source_child
        self.dict = dict
        self.comput_all_trans = comput_all_trans

        self.rpy_cal = []
        self.rpy_init = []

        self.deltaT = []
        self.deltaR = []

    @staticmethod
    def _buildEntry(trans, rot, x, y, z, roll, pitch, yaw):
        """Return an error entry whose keys follow the table column order."""
        return {'Trans': trans, 'Rot': rot,
                'x': x, 'y': y, 'z': z,
                'roll': roll, 'pitch': pitch, 'yaw': yaw}

    def computeFrameAverages(self) -> None:
        """Average the errors over all evaluated collections.

        Per-axis values are averaged in absolute value, for rotation as well as
        for translation, so that errors of opposite sign do not cancel out and
        the averages agree with the absolute values stored per collection.

        Raises:
            ZeroDivisionError: If no collection was evaluated.
        """
        self.avg_frame_trans_x = np.mean([abs(delta[0]) for delta in self.deltaT])
        self.avg_frame_trans_y = np.mean([abs(delta[1]) for delta in self.deltaT])
        self.avg_frame_trans_z = np.mean([abs(delta[2]) for delta in self.deltaT])
        self.avg_frame_rot_r = np.mean([abs(delta[0]) for delta in self.deltaR])
        self.avg_frame_rot_p = np.mean([abs(delta[1]) for delta in self.deltaR])
        self.avg_frame_rot_y = np.mean([abs(delta[2]) for delta in self.deltaR])

        evaluated = self.values_per_collection.values()
        self.average_trans = sum(values['trans'] for values in evaluated) / len(self.values_per_collection)
        self.average_rot = sum(values['rot'] for values in evaluated) / len(self.values_per_collection)

    def computeTransRot(self) -> None:
        """Compute the translation and rotation deltas of every shared collection.

        Fills ``deltaT`` / ``deltaR`` (in ascending collection order) and
        ``values_per_collection``, which maps each evaluated collection key to
        its deltas (``'deltaT'``, ``'deltaR'``) and their norms (``'trans'``,
        ``'rot'``).
        """
        train_collections = self.mixed_dataset['collections']
        test_od = OrderedDict(sorted(self.test_dataset['collections'].items(), key=lambda t: int(t[0])))

        # Start from a clean state so that calling this method twice does not
        # accumulate the deltas of the previous call.
        self.deltaT = []
        self.deltaR = []
        self.values_per_collection = {}

        for collection_key, collection in test_od.items():
            # With atom, some collections may be removed during the calibration,
            # so only the collections available in both datasets are compared.
            if collection_key not in train_collections:
                continue

            train_transform = getTransform(self.target_frame, self.source_child, train_collections[collection_key]['transforms'])
            test_transform = getTransform(self.target_frame, self.source_child, collection['transforms'])
            # presumably 4x4 homogeneous matrices - the slicing below relies on it
            delta = np.dot(np.linalg.inv(train_transform), test_transform)
            delta_t = delta[0:3, 3]
            delta_r = matrixToRodrigues(delta[0:3, 0:3])

            self.deltaT.append(delta_t)
            self.deltaR.append(delta_r)
            # Keep the deltas next to their collection key: the position in the
            # lists above is NOT the collection number when keys are missing.
            self.values_per_collection[collection_key] = {
                'trans': np.linalg.norm(delta_t),
                'rot': np.linalg.norm(delta_r),
                'deltaT': delta_t,
                'deltaR': delta_r,
            }

    def computeFinalValues(self) -> None:
        """Write the results into ``dict`` according to the selected mode."""
        if self.comput_all_trans:
            self.saveAllTrans()
        else:
            self.saveSourceTrans()

    def _averagesEntry(self):
        """Return the entry holding the averaged errors of this frame."""
        return self._buildEntry(
            self.average_trans, self.average_rot,
            abs(self.avg_frame_trans_x), abs(self.avg_frame_trans_y), abs(self.avg_frame_trans_z),
            abs(self.avg_frame_rot_r), abs(self.avg_frame_rot_p), abs(self.avg_frame_rot_y))

    def saveAllTrans(self) -> None:
        """Store the averaged errors of this frame under ``source_frame``."""
        self.dict[self.source_frame] = self._averagesEntry()

    def saveSourceTrans(self) -> None:
        """Store one entry per evaluated collection, followed by ``'Averages'``.

        Collections of the mixed dataset that were not evaluated (because they
        are missing from the test dataset) are reported and left out.
        """
        for collection_key in sorted(self.mixed_dataset['collections'], key=int):
            values = self.values_per_collection.get(collection_key)
            if values is None:
                print(f'Collection {Fore.RED}{collection_key}{Style.RESET_ALL} discarded due to differences in datasets lengths.')
                continue

            delta_t = values['deltaT']
            delta_r = values['deltaR']
            self.dict[collection_key] = self._buildEntry(
                values['trans'], values['rot'],
                abs(delta_t[0]), abs(delta_t[1]), abs(delta_t[2]),
                abs(delta_r[0]), abs(delta_r[1]), abs(delta_r[2]))

        self.dict['Averages'] = self._averagesEntry()

    def saveToDict(self) -> None:
        """Run the whole pipeline: deltas, averages, then the output entries."""
        self.computeTransRot()
        self.computeFrameAverages()
        self.computeFinalValues()

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------
def evaluateTargetFrame(target_frame, source_child):
    """Exit the script when the transform between the two frames cannot be obtained.

    The transform from ``target_frame`` to ``source_child`` is requested for
    the first collection of the module-level ``train_dataset``. Any failure is
    reported as the target frame not being in the transformation chain of the
    source child, and the script is terminated.

    Args:
        target_frame: Name of the target frame.
        source_child: Name of the child link of the source frame.
    """
    try:
        first_collection = next(iter(train_dataset['collections'].values()))
        # Only the success of the lookup matters, the transform is not used.
        getTransform(target_frame, source_child, first_collection['transforms'])
    except Exception:
        # `Exception` instead of a bare except: Ctrl+C and explicit exits must
        # not be reported as a frame problem.
        print(f'Target frame {Fore.RED}{target_frame}{Style.RESET_ALL} is not in the transformation chain of {Fore.RED}{source_child}{Style.RESET_ALL}')
        sys.exit()

def loadArgJson(arg):
    """Load and return the content of a JSON file.

    Args:
        arg: Path of the JSON file to read.

    Returns:
        The deserialized JSON content.
    """
    # The context manager closes the file even if the parsing fails.
    with open(arg, 'r') as json_file:
        return json.load(json_file)

def argToTargetFrame(arg, train_dataset):
    """Return the target frame to use for the evaluation.

    Args:
        arg: Target frame requested by the user, or None to pick a default.
        train_dataset: Dataset with a ``'collections'`` mapping; only read when
            ``arg`` is None.

    Returns:
        ``arg`` when it is given. Otherwise the ``'parent'`` of the first
        transform that declares one in the first collection (e.g. base_link),
        or None when there is no collection or no such transform.
    """
    # Return the value received as parameter; the function must not depend on
    # the module-level `args` dictionary.
    if arg is not None:
        return arg

    collections = train_dataset['collections']
    if not collections:
        return None

    first_transforms = next(iter(collections.values()))['transforms']
    for transform in first_transforms.values():
        if 'parent' in transform:
            return transform['parent']

    return None

def evaluateSourceFrame(source_frame, train_dataset):
    """Resolve a source frame name into its transform key and child link.

    The name is searched first in the ``additional_tfs`` of the calibration
    configuration (when that section is not empty) and then in its
    ``sensors``. If it is found in neither, a message is printed and the
    script exits.

    Args:
        source_frame: Name of the frame, as written in the calibration config.
        train_dataset: Dataset holding the ``'calibration_config'`` section.

    Returns:
        A ``(transform_key, source_child)`` tuple, where ``transform_key`` is
        built by ``generateKey`` from the parent and child links of the frame
        and ``source_child`` is its child link.
    """
    calibration_config = train_dataset['calibration_config']
    additional_tfs = calibration_config['additional_tfs']

    if additional_tfs and source_frame in additional_tfs:
        frame_config = additional_tfs[source_frame]
    elif source_frame in calibration_config['sensors']:
        frame_config = calibration_config['sensors'][source_frame]
    else:
        print(f'Source frame {Fore.RED}{source_frame}{Style.RESET_ALL} in not a known frame. Use names matching the ones on the {Fore.RED}config.yaml{Style.RESET_ALL} calibration file')
        sys.exit()

    parent_link = frame_config['parent_link']
    child_link = frame_config['child_link']
    return generateKey(parent_link, child_link), child_link

# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------
if __name__ == "__main__":
    # ---------------------------------------
    # --- Command line arguments
    # ---------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("-train_json", "--train_json_file", type=str, required=True,
                        help="Json file containing train input dataset.")
    parser.add_argument("-test_json", "--test_json_file", type=str, required=True,
                        help="Json file containing test input dataset.")
    parser.add_argument("-tf", "--target_frame", type=str, required=False,
                        help="Target transformation frame. If no frame is provided, the first link in the calibration chain after the reference frame will be assigned.")
    parser.add_argument("-sf", "--source_frame", type=str, required=False,
                        help="Source transformation frame. If no frame is provided, the errors will be computed for all estimated frames.")
    parser.add_argument("-sac", "--show_all_collections", action="store_true", default=False, required=False,
                        help="In the outputted table, it displays the result for every collections, as opposed to the averages.")
    # Options controlling the export of the results to a csv file
    parser.add_argument("-sfr", "--save_file_results", type=str, required=False,
                        help="Output folder to where the results will be stored.")
    parser.add_argument("-sfrn", "--save_file_results_name", type=str, required=False,
                        help="Name of csv file to save the results. "
                             "Default: {name_of_dataset}_ground_truth_frame_results.csv")

    # Only the recognised arguments are kept; unknown ones are ignored.
    known_args = parser.parse_known_args()[0]
    args = vars(known_args)

    # ---------------------------------------
    # --- INITIALIZATION Load both datasets and mix them
    # ---------------------------------------
    train_dataset = loadArgJson(args['train_json_file'])
    test_dataset = loadArgJson(args['test_json_file'])
    # The mixed dataset is built by getMixedDataset from the train and test datasets
    mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    # ---------------------------------------
    # --- STEP 1: Fill a dictionary with the error values
    # ---------------------------------------
    errors_dict = {}  # every computed error ends up here
    target_frame = argToTargetFrame(args['target_frame'], mixed_dataset)
    # Without an explicit source frame, every estimated frame is evaluated
    compute_all_transformations = args['source_frame'] is None

    if not compute_all_transformations:
        # A single source frame was selected by the user
        source_frame = args['source_frame']
        transform_key, source_child = evaluateSourceFrame(source_frame, mixed_dataset)
        evaluateTargetFrame(target_frame, source_child)
        errors = ErrorsDictionary(mixed_dataset, test_dataset, transform_key, target_frame,
                                  source_frame, source_child, errors_dict, compute_all_transformations)
        errors.saveToDict()
    else:
        # Additional tfs come first, then the sensors
        calibration_config = mixed_dataset['calibration_config']
        for group in ('additional_tfs', 'sensors'):
            estimated_frames = calibration_config[group]
            # The additional tfs section is optional and may be empty
            if group == 'additional_tfs' and not estimated_frames:
                continue

            for frame_config in estimated_frames.values():
                transform_key = generateKey(frame_config['parent_link'], frame_config['child_link'])
                source_child = frame_config['child_link']
                evaluateTargetFrame(target_frame, source_child)
                errors = ErrorsDictionary(mixed_dataset, test_dataset, transform_key, target_frame,
                                          transform_key, source_child, errors_dict, compute_all_transformations)
                errors.saveToDict()

    # -------------------------------------------------------------
    # STEP 2: Show the results
    # -------------------------------------------------------------
    if not errors_dict:
        print('Nothing to output. Please verify the commands')
    else:
        table = Table(errors_dict, args, compute_all_transformations, target_frame, source_child)
        table.printTable()
        

