#!/usr/bin/env python3

import os
import json
import argparse
import math
from copy import deepcopy

import cv2
import numpy as np
import math
import sys
from prettytable import PrettyTable
from colorama import Style, Fore

sys.path.insert(1, '../utils')
from links import joint_correspondence_human36m, joint_correspondence, joint_correspondence_human36m_mediapipe, links_mpi_inf_3dhp, joints_mpi_inf_3dhp, joint_correspondence_mpi_mediapipe
from projection import projectToCamera
from draw import drawSquare2D, drawCross2D


def _load_json(path, description):
    """Return the parsed JSON document stored at *path*.

    Args:
        path: File path to read.
        description: Human-readable name of the file, used in the error message.

    Raises:
        Exception: If *path* does not exist.
    """
    if not os.path.exists(path):
        raise Exception("The " + description + " file does not exist!")
    with open(path, "r") as fp:
        return json.load(fp)


def _highlight(text):
    """Wrap *text* in the colorama codes used to emphasise a table cell."""
    return Fore.BLUE + Style.BRIGHT + text + Fore.BLACK + Style.NORMAL


def _format_means(samples):
    """Return the mean RMS, X and Y errors of *samples* as table cells.

    Args:
        samples: List of per-joint error dicts with the keys 'rmse',
            'x_error' and 'y_error'.

    Returns:
        Three strings ordered as the table columns (RMS, X err, Y err).
        When *samples* is empty, 'n/a' placeholders are returned instead of
        dividing by zero.
    """
    if not samples:
        return ['n/a', 'n/a', 'n/a']
    count = len(samples)
    return ['%.3f' % (sum(sample[key] for sample in samples) / count)
            for key in ('rmse', 'x_error', 'y_error')]


def _compute_errors(cameras, ground_truth, joint_map, detector, show_images):
    """Compute the 2D reprojection error of every valid detected joint.

    For each camera, frame and joint in *joint_map*, the ground-truth 3D point
    is transformed with the camera extrinsics, projected with
    ``projectToCamera`` and compared with the detected pixel position.

    Args:
        cameras: Camera dictionary loaded from the cameras JSON file.
        ground_truth: Ground-truth dictionary indexed by frame key, then by
            ground-truth joint key.
        joint_map: Mapping joint name -> dict with the keys 'body25' (key of
            the detected joint in the frame) and 'ground_truth'.
        detector: 'openpose' or 'mediapipe'; mediapipe images are shown at
            half size.
        show_images: When True, each frame image is loaded, annotated and
            displayed, waiting for a key press per frame.

    Returns:
        Nested dict ``errors[camera_key][frame_key][joint_key]`` holding
        'x_error', 'y_error' and 'rmse' (Euclidean pixel distance). Joints
        flagged as not valid have no entry.
    """
    errors = {}
    for camera_key, camera in cameras.items():
        errors[camera_key] = {}
        for frame_key, frame in camera['frames'].items():
            window_name = camera_key + '_' + frame_key

            # The image is only used for visualisation, so it is read and
            # annotated only when the user asked to see it.
            img = None
            if show_images:
                img = cv2.imread(frame['image_path'])
                if img is None:
                    print("Warning: could not read image " + str(frame['image_path'])
                          + "; skipping its visualisation.", file=sys.stderr)

            frame_errors = {}
            errors[camera_key][frame_key] = frame_errors
            for joint_key, joint in joint_map.items():
                joint_detected = frame['joints'][joint['body25']]
                if not joint_detected['valid']:  # skip invalid joints
                    continue

                # Homogeneous 3D ground-truth point in the world frame.
                pose = ground_truth[frame_key][joint['ground_truth']]['pose']
                pts_in_world = np.array([[pose['x']], [pose['y']], [pose['z']], [1.0]], dtype=float)

                # Transform to the camera's coordinate frame
                pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)

                # Project 3D point to camera
                pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                     np.array(camera['distortion']),
                                                     camera['width'],
                                                     camera['height'], pts_in_sensor[0:3, :])

                xpix_projected = pts_in_image[0][0]
                ypix_projected = pts_in_image[1][0]
                xpix_detected = joint_detected['x']
                ypix_detected = joint_detected['y']

                if img is not None:
                    # Square = projected ground truth, cross = 2D detection.
                    color = joint_detected['color']
                    drawSquare2D(img, xpix_projected, ypix_projected, 10, color=color, thickness=3)
                    drawCross2D(img, xpix_detected, ypix_detected, 10, color=color, thickness=3)

                frame_errors[joint_key] = {
                    'x_error': abs(xpix_detected - xpix_projected),
                    'y_error': abs(ypix_detected - ypix_projected),
                    # Stored under 'rmse' but this is the Euclidean pixel
                    # distance between the detection and the projection.
                    'rmse': math.dist([xpix_detected, ypix_detected], [xpix_projected, ypix_projected]),
                }

            if img is not None:
                # Show the fully annotated frame once, then wait for a key.
                if detector == 'mediapipe':
                    img = cv2.resize(img, None, fx=0.5, fy=0.5)
                cv2.imshow(window_name, img)
                cv2.waitKey(0)
    return errors


def main():
    """Evaluate 2D detections against projected ground truth and print tables.

    Parses the command line, loads the ground-truth and camera JSON files of
    the selected dataset, computes the per-joint reprojection errors and prints:

    1. per-frame averages for every camera (or, with ``--extended_tables``,
       one table per camera that also lists every joint),
    2. per-frame averages for every camera with the frame highlighted,
    3. per-joint averages over all frames for every camera.

    Raises:
        Exception: If the dataset folder, the ground-truth file or the camera
            dictionary file does not exist.
    """
    # --------------------------------------------------------
    # Arguments
    # --------------------------------------------------------

    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Section (subject) to optimize. Used when the dataset is set to human36m or mpi",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence to optimize. Only works when the dataset is set to mpi",
                    type=str, default='Seq1')
    ap.add_argument("-si", "--show_images", help="Shows projection images", action="store_true",
                    default=False)
    ap.add_argument("-ext", "--extended_tables", help="Shows extended tables", action="store_true",
                    default=False)
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: openpose, mediapipe (default)",
                    type=str, default='mediapipe')
    args = vars(ap.parse_args())

    # --------------------------------------------------------
    # Loading files, and verifications configurations
    # --------------------------------------------------------
    # Any value other than "mediapipe" selects openpose.
    detector = 'mediapipe' if args['2d_detector'] == "mediapipe" else 'openpose'

    # Start from the generic imported correspondence so that combinations
    # without a dedicated table (e.g. mpi + openpose, custom datasets) no
    # longer fail with UnboundLocalError.
    # NOTE(review): confirm the generic table is the right one for those cases.
    dataset_folder = '../../images/' + args['dataset_name'] + '/'
    joint_map = joint_correspondence

    if args['dataset_name'] == "human36m":
        dataset_folder = '../../images/human36m/processed/' + args['section'] + '/' + args['action'] + '/'
        joint_map = joint_correspondence_human36m
        if detector == "mediapipe":
            joint_map = joint_correspondence_human36m_mediapipe
    elif args['dataset_name'] == "mpi":
        dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
        if detector == "mediapipe":
            joint_map = joint_correspondence_mpi_mediapipe

    # check if dataset folder exists:
    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")

    ground_truth = _load_json(dataset_folder + "ground_truth.txt", "ground truth")

    camera_file = "cameras_op.json" if detector == 'openpose' else "cameras_mp.json"
    cameras = _load_json(dataset_folder + camera_file, "camera dictionary")

    # --------------------------------------------------------
    # Calculate projections
    # --------------------------------------------------------
    errors = _compute_errors(cameras, ground_truth, joint_map, detector, args['show_images'])

    # -------------------------------------------------------------
    # Print output table
    # -------------------------------------------------------------
    table_header = ['Camera', 'Frame #', 'Joint', 'RMS (pix)', 'X err (pix)', 'Y err (pix)']

    if not args['extended_tables']:
        # One compact table: the average error of every (camera, frame) pair.
        table = PrettyTable(table_header)
        for camera_key, frames in errors.items():
            for frame_key, frame_errors in frames.items():
                table.add_row([camera_key, frame_key, _highlight('Average')]
                              + _format_means(list(frame_errors.values())))
        print(table)
    else:
        # One table per camera: every joint of every frame plus the average.
        separator = ['--------', '--------', '-----------', '--------', '--------', '--------']
        for camera_key, frames in errors.items():
            table = PrettyTable(table_header)
            for frame_key, frame_errors in frames.items():
                for joint_key, joint_error in frame_errors.items():
                    table.add_row([camera_key, frame_key, joint_key,
                                   '%.3f' % joint_error['rmse'],
                                   '%.3f' % joint_error['x_error'],
                                   '%.3f' % joint_error['y_error']])
                table.add_row(list(separator))
                table.add_row([camera_key, frame_key, _highlight('Average')]
                              + _format_means(list(frame_errors.values())))
                table.add_row(list(separator))
            print(table)

    # Per-frame averages. Iterating every camera explicitly replaces the
    # previous reliance on loop variables leaked from the block above, which
    # silently reported only the last camera.
    table = PrettyTable(table_header)
    for camera_key, frames in errors.items():
        for frame_key, frame_errors in frames.items():
            table.add_row([camera_key, _highlight(frame_key), 'Average']
                          + _format_means(list(frame_errors.values())))
    print(table)

    # Per-joint averages over all frames in which the joint was valid.
    table = PrettyTable(table_header)
    for camera_key, frames in errors.items():
        for joint_key in joint_map:
            samples = [frame_errors[joint_key] for frame_errors in frames.values()
                       if joint_key in frame_errors]
            table.add_row([camera_key, _highlight('Average'), joint_key]
                          + _format_means(samples))
    print(table)


# Run the evaluation only when this file is executed as a script, so that
# importing the module does not parse arguments or read any dataset.
if __name__ == "__main__":
    main()
