#!/usr/bin/env python3
"""
Reads the calibration results from a json file and computes the evaluation metrics

Summary of approach to compute homography by using solvepnp

Problem definition: Two sensors: sensor source (ss) and sensor target (st), and one pattern p.
We can write the projection equations of the pattern corner coordinates (defined in the pattern's local coordinate frame) to the image of each sensor
[u_ss] = [fx   0  cx]   * [r11 r12 r13 tx] * [X_p]
[v_ss] = [ 0  fy  cy]     [r21 r22 r23 ty]   [Y_p]
[w_ss] = [ 0   0   1]_ss  [r31 r32 r33 tz]   [Z_p]
                          [  0   0   0  1]   [1  ]

In matrix notation, written for both sensors:

[u]_ss = K_ss * ss_T_p * [X]_p
[u]_st = K_st * st_T_p * [X]_p

ss_T_p can be retrieved using perspective n point (pnp), and st_T_p is retrieved from the transformation between both sensors st_T_ss, which was estimated by the calibration:

st_T_p = st_T_ss * ss_T_p

Because we are considering only points in the pattern's plane, then Z_p = 0, which leads to the removal of the last row and the third column from the transformation matrix T

[u_ss] = [fx   0  cx]   * [r11 r12  tx] * [X_p]
[v_ss] = [ 0  fy  cy]     [r21 r22  ty]   [Y_p]
[w_ss] = [ 0   0   1]_ss  [r31 r32  tz]   [  1]

and finally we can compute the homography matrices ss_H_p and st_H_p, one for each sensor

[u]_ss = ss_H_p * [X]_p
[u]_st = st_H_p * [X]_p

In both equations the pattern corners [X]_p are the same, which leads to:

[u]_st = st_H_p * (ss_H_p)^-1 * [u]_ss

Using this equation we can compute the projections from the detections in ss to the image of st, and then
we compare the projections with the annotated corners in st
"""

# Standard imports
import argparse
import os
import math
import json
from collections import OrderedDict
from copy import deepcopy

# ROS imports
import cv2
import numpy as np
from matplotlib import cm
from colorama import Style, Fore
from prettytable import PrettyTable

# Atom imports
from atom_core.atom import getTransform
from atom_core.dataset_io import getMixedDataset, loadResultsJSON, filterCollectionsFromDataset
from atom_core.drawing import drawCross2D, drawSquare2D
from atom_core.geometry import matrixToRodrigues, traslationRodriguesToTransform
from atom_core.utilities import rootMeanSquare, saveFileResults

# -------------------------------------------------------------------------------
# --- IMPORTS
# -------------------------------------------------------------------------------


# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def homographyFromTransform(T):
    """Extract the 3x3 planar block of a homogeneous transform.

    Keeps rows 0..2 and columns 0, 1 and 3 of T, i.e. drops the third
    (Z) column and the last row. This is the reduction described in the
    module docstring for points lying on the pattern plane (Z_p = 0).

    :param T: 2D array-like indexable as T[row, col], at least 3x4.
    :return: a new 3x3 float numpy array.
    """
    kept_columns = [0, 1, 3]  # r1, r2 and t; the Z column is dropped

    H = np.zeros((3, 3), float)
    H[:, :] = np.asarray(T)[:3][:, kept_columns]
    return H


def undistortCorners(pts_in, K, D):
    """Remove lens distortion from corner points, returning pixel coordinates.

    :param pts_in: points stored one per column, i.e.
                   [x1 x2 ... xn]
                   [y1 y2 ... yn]
                   [ 1  1 ...  1]  (optional homogeneous row, ignored)
    :param K: 3x3 intrinsic matrix (fx, fy, cx, cy are read from it).
    :param D: distortion coefficients, forwarded to cv2.undistortPoints.
    :return: 3xN float32 array of undistorted pixel coordinates with a
             homogeneous row of ones.
    """
    n_points = pts_in.shape[1]

    # Drop the homogeneous row and transpose: opencv wants one point per row.
    pixel_xy = pts_in[0:2].T
    normalized = cv2.undistortPoints(pixel_xy, K, D)

    focal = (K[0, 0], K[1, 1])
    center = (K[0, 2], K[1, 2])

    # The result of undistortPoints is scaled by the focal length and shifted
    # by the principal point to get back to pixel coordinates.
    undistorted_corners = np.ones((3, n_points), np.float32)
    for axis in (0, 1):
        undistorted_corners[axis, :] = normalized[:, 0, axis] * focal[axis] + center[axis]

    return undistorted_corners


def distortCorners(corners, K, D):
    """Apply radial and tangential lens distortion to undistorted pixel coordinates.

    Implements the x'' / y'' equations of the OpenCV camera model, see
    https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html

    :param corners: array with x in row 0 and y in row 1, one point per column.
    :param K: 3x3 intrinsic matrix (fx, fy, cx, cy are read from it).
    :param D: exactly five distortion coefficients, unpacked as (k1, k2, p1, p2, k3).
    :return: 3xN float32 array of distorted pixel coordinates with a
             homogeneous row of ones.
    """
    k1, k2, p1, p2, k3 = D
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Pixel coordinates -> normalized image coordinates
    x_n = (corners[0, :] - cx) / fx
    y_n = (corners[1, :] - cy) / fy

    # Apply the distortion model
    r2 = x_n ** 2 + y_n ** 2  # squared radius, shared by all terms below
    radial = 1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3
    x_d = x_n * radial + 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n ** 2)
    y_d = y_n * radial + p1 * (r2 + 2 * y_n ** 2) + 2 * p2 * x_n * y_n

    # Normalized image coordinates -> pixel coordinates
    distorted_corners = np.ones((3, corners.shape[1]), np.float32)
    distorted_corners[0, :] = x_d * fx + cx
    distorted_corners[1, :] = y_d * fy + cy

    return distorted_corners


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------

if __name__ == "__main__":
    # ---------------------------------------
    # --- Command line arguments
    # ---------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-train_json", "--train_json_file", help="Json file containing train input dataset.", type=str,
                    required=True)
    ap.add_argument("-test_json", "--test_json_file", help="Json file containing test input dataset.", type=str,
                    required=True)
    ap.add_argument("-ss", "--sensor_source", help="Source transformation sensor.", type=str, required=True)
    ap.add_argument("-st", "--sensor_target", help="Target transformation sensor.", type=str, required=True)
    ap.add_argument("-si", "--show_images", help="If true the script shows images.", action='store_true', default=False)
    # NOTE(security): the -csf string is handed to eval(). Only run this script with command lines you trust.
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help="A string to be evaluated into a lambda function that receives a collection name as input and "
                    "returns True or False to indicate if the collection should be loaded (and used in the "
                    "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
                    "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-rpd", "--remove_partial_detections", help="Remove detected labels which are only partial."
                            "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument("-pn", "--pattern_name",
                    help="Name of the pattern for which the evaluation will be performed", type=str, default='')

    # save results in a csv file
    ap.add_argument("-sfr", "--save_file_results", help="Store the results", action='store_true', default=False)
    ap.add_argument("-sfrn", "--save_file_results_name", help="Name of csv file to save the results. "
                    "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv", type=str, required=False)

    # parse args
    args = vars(ap.parse_known_args()[0])
    sensor_source = args['sensor_source']
    sensor_target = args['sensor_target']

    # ---------------------------------------
    # --- INITIALIZATION Read calibration data from files
    # ---------------------------------------
    # Loads the train json file containing the calibration results
    train_dataset, train_json_file = loadResultsJSON(args["train_json_file"], args["collection_selection_function"])

    # Loads the test json file containing a set of collections to evaluate the calibration
    test_dataset, test_json_file = loadResultsJSON(args["test_json_file"], args["collection_selection_function"])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    test_dataset = filterCollectionsFromDataset(test_dataset, args)  # filter collections

    # ---------------------------------------
    # --- Get mixed json (calibrated transforms from train and the rest from test)
    # ---------------------------------------
    original_mixed_dataset = getMixedDataset(train_dataset, test_dataset)

    # ---------------------------------------
    # --- Get intrinsic data for both sensors
    # ---------------------------------------
    # Source sensor
    camera_info_s = original_mixed_dataset['sensors'][sensor_source]['camera_info']
    K_s = np.zeros((3, 3), np.float32)
    D_s = np.zeros((5, 1), np.float32)
    K_s[0, :] = camera_info_s['K'][0:3]
    K_s[1, :] = camera_info_s['K'][3:6]
    K_s[2, :] = camera_info_s['K'][6:9]
    D_s[:, 0] = camera_info_s['D'][0:5]

    # Target sensor
    camera_info_t = original_mixed_dataset['sensors'][sensor_target]['camera_info']
    K_t = np.zeros((3, 3), np.float32)
    D_t = np.zeros((5, 1), np.float32)
    K_t[0, :] = camera_info_t['K'][0:3]
    K_t[1, :] = camera_info_t['K'][3:6]
    K_t[2, :] = camera_info_t['K'][6:9]
    D_t[:, 0] = camera_info_t['D'][0:5]

    # Defining frames (these do not change per pattern nor per collection)
    ss_frame = original_mixed_dataset['calibration_config']['sensors'][sensor_source]['link']
    st_frame = original_mixed_dataset['calibration_config']['sensors'][sensor_target]['link']

    # Patterns to evaluate
    if args['pattern_name'] == '':
        patterns_to_evaluate = original_mixed_dataset['calibration_config']['calibration_patterns'].keys()
    else:
        patterns_to_evaluate = [args['pattern_name']]

    for pattern_key in patterns_to_evaluate:
        mixed_dataset = deepcopy(original_mixed_dataset)

        # Deleting collections where the pattern is not found by the source or the target sensor.
        # Only these two sensors are inspected, so other sensors need not have labels for this pattern.
        collections_to_delete = []
        for collection_key, collection in mixed_dataset['collections'].items():
            for sensor_key in (sensor_source, sensor_target):
                if not collection['labels'][pattern_key][sensor_key]['detected']:
                    print(
                        Fore.RED + "Removing collection " + collection_key + ' -> pattern was not found in sensor ' +
                        sensor_key + ' (must be found in all sensors).' + Style.RESET_ALL)

                    collections_to_delete.append(collection_key)
                    break

        for collection_key in collections_to_delete:
            del mixed_dataset['collections'][collection_key]

        # Without collections there is nothing to average (this used to end in a ZeroDivisionError).
        if not mixed_dataset['collections']:
            print(Fore.YELLOW + 'No collections left to evaluate pattern ' + pattern_key + '. Skipping.' +
                  Style.RESET_ALL)
            continue

        common_frame = mixed_dataset['calibration_config']['world_link']

        # Get pattern number of corners
        pattern_config = mixed_dataset['calibration_config']['calibration_patterns'][pattern_key]
        nx = pattern_config['dimension']['x']
        ny = pattern_config['dimension']['y']
        square = pattern_config['size']

        # -------------------------------------------------------------
        # STEP 1: Define corner coordinates in the pattern's local coordinate frame.
        # -------------------------------------------------------------
        # One row per corner (X_p, Y_p, Z_p = 0); the row index is the corner id used by the labels.
        objp = np.zeros((nx * ny, 3), float)
        objp[:, :2] = square * np.mgrid[0:nx, 0:ny].T.reshape(-1, 2)

        e = {}  # dictionary with all the errors
        sorted_collections = OrderedDict(sorted(mixed_dataset['collections'].items(), key=lambda t: int(t[0])))

        for collection_key, collection in sorted_collections.items():
            e[collection_key] = {}  # init the dictionary of errors for this collection

            labels_s = collection['labels'][pattern_key][sensor_source]['idxs']
            labels_t = collection['labels'][pattern_key][sensor_target]['idxs']

            # Get corners and idxs for the source sensor
            corners_s = np.ones((3, len(labels_s)), dtype=float)
            corners_s[0, :] = [label['x'] for label in labels_s]
            corners_s[1, :] = [label['y'] for label in labels_s]
            idxs_s = [label['id'] for label in labels_s]

            # Get corners and idxs for the target sensor
            corners_t = np.ones((3, len(labels_t)), dtype=float)
            corners_t[0, :] = [label['x'] for label in labels_t]
            corners_t[1, :] = [label['y'] for label in labels_t]
            idxs_t = [label['id'] for label in labels_t]

            # Map corner id -> column in corners_t. setdefault keeps the first occurrence of an id.
            target_idx_by_id = {}
            for label_t_idx, label_t in enumerate(labels_t):
                target_idx_by_id.setdefault(label_t['id'], label_t_idx)

            # for each labeled point in the source image, project in to the target image and measure the distance to the
            # detection in the target image

            # -------------------------------------------------------------
            # STEP 2. Compute ss_T_p
            # -------------------------------------------------------------
            _, rvecs, tvecs = cv2.solvePnP(objp[idxs_s], np.array(corners_s[0: 2, :].T, dtype=np.float32),
                                           K_s, D_s)
            ss_T_p = traslationRodriguesToTransform(tvecs, rvecs)

            # -------------------------------------------------------------
            # STEP 3. Compute st_T_p
            # -------------------------------------------------------------
            st_T_ss = getTransform(st_frame, ss_frame, collection['transforms'])
            st_T_p = np.dot(st_T_ss, ss_T_p)

            # -------------------------------------------------------------
            # STEP 4: Compute homography matrices for both sensors and the combined homography
            # -------------------------------------------------------------
            ss_H_p = np.dot(K_s, homographyFromTransform(ss_T_p))
            st_H_p = np.dot(K_t, homographyFromTransform(st_T_p))
            st_H_ss = np.dot(st_H_p, np.linalg.inv(ss_H_p))  # combined homography

            # -------------------------------------------------------------
            # STEP 5: Remove distortion from source sensor corners
            # -------------------------------------------------------------
            ucorners_s = undistortCorners(corners_s, K_s, D_s)

            # -------------------------------------------------------------
            # STEP 6: Project from source sensor to target sensor
            # -------------------------------------------------------------
            ucorners_s_proj_to_t = np.dot(st_H_ss, ucorners_s)
            # Normalize pixel coordinates to have w = 1
            ucorners_s_proj_to_t = ucorners_s_proj_to_t / np.tile(ucorners_s_proj_to_t[2, :], (3, 1))

            # -------------------------------------------------------------
            # STEP 7: adding distortion to the projections.
            # -------------------------------------------------------------
            corners_s_proj_to_t = distortCorners(ucorners_s_proj_to_t, K_t, D_t)

            # -------------------------------------------------------------
            # STEP 8: Compute the error whenever a detection of a given corner is successful in the source and also the target sensors.
            # -------------------------------------------------------------
            x_errors = []
            y_errors = []
            rms_errors = []
            for label_s_idx, label_s in enumerate(labels_s):
                label_t_idx = target_idx_by_id.get(label_s['id'])
                if label_t_idx is None:  # corner not detected in the target sensor
                    continue

                delta_x = corners_t[0, label_t_idx] - corners_s_proj_to_t[0, label_s_idx]
                delta_y = corners_t[1, label_t_idx] - corners_s_proj_to_t[1, label_s_idx]

                x_errors.append(abs(delta_x))
                y_errors.append(abs(delta_y))
                rms_errors.append(math.sqrt(delta_x ** 2 + delta_y ** 2))

            e[collection_key]['x'] = np.average(x_errors)
            e[collection_key]['y'] = np.average(y_errors)
            e[collection_key]['rms'] = rootMeanSquare(rms_errors)

            # -------------------------------------------------------------
            # STEP 9: Compute translation and rotation errors (This is from Eurico, did not change style)
            # -------------------------------------------------------------
            # Pattern pose in the common frame, as seen through the target sensor
            _, rvecs, tvecs = cv2.solvePnP(objp[idxs_t],
                                           np.array(corners_t[0: 2, :].T, dtype=np.float32),
                                           K_t, D_t)
            pattern_pose_target = traslationRodriguesToTransform(tvecs, rvecs)
            bTp = getTransform(common_frame, st_frame, collection['transforms'])
            pattern_pose_target = np.dot(bTp, pattern_pose_target)

            # Pattern pose in the common frame, as seen through the source sensor. The pnp solution of
            # STEP 2 is reused, since it was computed from exactly the same inputs.
            bTp = getTransform(common_frame, ss_frame, collection['transforms'])
            pattern_pose_source = np.dot(bTp, ss_T_p)

            delta = np.dot(np.linalg.inv(pattern_pose_source), pattern_pose_target)

            deltaT = delta[0:3, 3]
            deltaR = matrixToRodrigues(delta[0:3, 0:3])

            e[collection_key]['trans'] = np.linalg.norm(deltaT) * 1000
            e[collection_key]['rot'] = np.linalg.norm(deltaR) * 180.0 / np.pi

            # -------------------------------------------------------------
            # STEP 10: Show projections (optional)
            # -------------------------------------------------------------
            if args['show_images']:
                # Images are only needed for visualization, so they are loaded here and not before.
                dataset_dir = os.path.dirname(test_json_file)
                path_s = os.path.join(dataset_dir, collection['data'][sensor_source]['data_file'])
                path_t = os.path.join(dataset_dir, collection['data'][sensor_target]['data_file'])

                image_gui_s = cv2.imread(path_s)
                image_gui_t = cv2.imread(path_t)
                if image_gui_s is None:  # cv2.imread signals failure by returning None
                    raise FileNotFoundError('Could not read image ' + path_s)
                if image_gui_t is None:
                    raise FileNotFoundError('Could not read image ' + path_t)

                window_name_s = 'Sensor ' + sensor_source + ' (source) - Collection ' + collection_key
                window_name_t = 'Sensor ' + sensor_target + ' (target) - Collection ' + collection_key
                cv2.namedWindow(window_name_s, cv2.WINDOW_NORMAL)
                cv2.namedWindow(window_name_t, cv2.WINDOW_NORMAL)
                cmap = cm.gist_rainbow(np.linspace(0, 1, nx * ny))

                # Iterate all corner detections in the source image and, if the same corner was detected on the target image, draw in color.
                for label_s_idx, label_s in enumerate(labels_s):
                    x_s, y_s = corners_s[0, label_s_idx], corners_s[1, label_s_idx]

                    label_t_idx = target_idx_by_id.get(label_s['id'])
                    if label_t_idx is None:
                        # Draw labels on source image (gray crosses)
                        drawCross2D(image_gui_s, x_s, y_s, 5, color=(140, 140, 140), thickness=1)
                        continue

                    x_t, y_t = corners_t[0, label_t_idx], corners_t[1, label_t_idx]
                    x_s_proj_to_t = corners_s_proj_to_t[0, label_s_idx]
                    y_s_proj_to_t = corners_s_proj_to_t[1, label_s_idx]

                    # cmap rows are RGBA in [0, 1]; opencv expects BGR in [0, 255]
                    color = (cmap[label_s['id'], 2] * 255, cmap[label_s['id'], 1]
                             * 255, cmap[label_s['id'], 0] * 255)

                    # Draw labels on source image (crosses)
                    drawCross2D(image_gui_s, x_s, y_s, 5, color=color, thickness=1)

                    # Draw labels on target image (squares)
                    drawSquare2D(image_gui_t, x_t, y_t, 6, color=color, thickness=1)

                    # Draw projections of source to target, i.e. proj_to_t (crosses)
                    drawCross2D(image_gui_t, x_s_proj_to_t, y_s_proj_to_t, 5, color=color, thickness=1)

                cv2.resizeWindow(window_name_s, 800, int(800/image_gui_s.shape[1] * image_gui_s.shape[0]))
                cv2.resizeWindow(window_name_t, 800, int(800/image_gui_t.shape[1] * image_gui_t.shape[0]))
                cv2.imshow(window_name_s, image_gui_s)
                cv2.imshow(window_name_t, image_gui_t)

                print('Errors for collection ' + collection_key + ':\n' + str(e[collection_key]))

                key = cv2.waitKey(0)
                cv2.destroyWindow(window_name_s)
                cv2.destroyWindow(window_name_t)

                # 'c' or 'q' disables the visualization for all remaining collections (and patterns)
                if key == ord('c') or key == ord('q'):
                    print('q pressed. Continuing ...')
                    args['show_images'] = False

        # -------------------------------------------------------------
        # STEP 11: Print output table
        # -------------------------------------------------------------
        table_header = ['Collection #', 'RMS (pix)', 'X err (pix)', 'Y err (pix)', 'Trans (mm)', 'Rot (deg)']
        table = PrettyTable(table_header)
        # table to save. This table was created, because the original has colors and the output csv save them as random characters
        table_to_save = PrettyTable(table_header)

        for collection_key in sorted_collections:
            row = [collection_key,
                   '%.4f' % e[collection_key]['rms'],
                   '%.4f' % e[collection_key]['x'],
                   '%.4f' % e[collection_key]['y'],
                   '%.4f' % e[collection_key]['trans'],
                   '%.4f' % e[collection_key]['rot']]

            table.add_row(row)
            table_to_save.add_row(row)

        # Compute averages and add a bottom row
        bottom_row = []
        bottom_row_save = []
        for col_idx, _ in enumerate(table_header):
            if col_idx == 0:
                bottom_row.append(Fore.BLUE + Style.BRIGHT + 'Averages' + Style.RESET_ALL)
                bottom_row_save.append('Averages')
                continue

            # Only the cells which can be parsed as numbers contribute to the average
            numeric_values = []
            for row in table.rows:
                try:
                    numeric_values.append(float(row[col_idx]))
                except (ValueError, TypeError):
                    pass

            if numeric_values:
                value = '%.4f' % (sum(numeric_values) / len(numeric_values))
            else:
                value = '%.4f' % float('nan')
            bottom_row.append(Fore.BLUE + value + Style.RESET_ALL)
            bottom_row_save.append(value)

        table.add_row(bottom_row)
        table_to_save.add_row(bottom_row_save)

        # Put larger errors in red per column (per sensor)
        for col_idx, _ in enumerate(table_header):
            if col_idx == 0:  # nothing to do
                continue

            max_value = 0
            max_row_idx = 0
            for row_idx, row in enumerate(table.rows[:-1]):  # ignore bottom row
                try:
                    value = float(row[col_idx])
                except (ValueError, TypeError):
                    continue

                if value > max_value:
                    max_value = value
                    max_row_idx = row_idx

            # set the max column value to red
            table.rows[max_row_idx][col_idx] = Fore.RED + table.rows[max_row_idx][col_idx] + Style.RESET_ALL

        table.align = 'c'
        table_to_save.align = 'c'
        print(Style.BRIGHT + 'Errors per collection' + Style.RESET_ALL)
        print(table)

        # save results in csv file
        if args['save_file_results']:
            if args['save_file_results_name'] is None:
                results_name = f'{args["sensor_source"]}_to_{args["sensor_target"]}_results.csv'
                saveFileResults(args['train_json_file'], args['test_json_file'], results_name, table_to_save)
            else:
                with open(args['save_file_results_name'], 'w', newline='') as f_output:
                    f_output.write(table_to_save.get_csv_string())

    print('Ending script...')
    exit()
