#!/usr/bin/env python3

import os
import json
import argparse
import math
import numpy as np
import cv2
import math
import sys
from prettytable import PrettyTable
from colorama import Style, Fore

sys.path.insert(1, '../utils')
from links import joint_correspondence, joint_correspondence_human36m, links_evaluation, links_evaluation_human36m, \
    links_evaluation_human36m_mediapipe, joint_correspondence_human36m_mediapipe, links_mpi_inf_3dhp, joints_mpi_inf_3dhp
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d


def press(event):
    """Matplotlib ``key_press_event`` handler: quit the program on 'q' or 'Q'.

    Uses ``sys.exit`` rather than the ``exit`` helper injected by the ``site``
    module, which is intended for the interactive shell and may be missing
    (e.g. under ``python -S`` or in frozen executables).

    Args:
        event: the key event delivered by matplotlib; only ``event.key`` is read.
    """
    if event.key in ('q', 'Q'):
        sys.exit()


# Error fields, in the same order as the numeric columns of _TABLE_HEADER.
_ERROR_FIELDS = ('rmse', 'x_error', 'y_error', 'z_error')
_TABLE_HEADER = ('Frame #', 'Joint', 'RMS (m)', 'X err (m)', 'Y err (m)', 'Z err (m)')
_SEPARATOR_ROW = ('--------', '--------', '-----------', '--------', '--------', '--------')
# Joints left out of the "Average Best Joints" row of the extended tables.
_EXCLUDED_FROM_BEST = frozenset({'right_ankle', 'left_ankle', 'left_hip', 'right_hip',
                                 'right_shoulder', 'left_shoulder'})
# Links drawn in light gray instead of black; the set depends on the dataset skeleton.
_GRAY_LINKS_DEFAULT = frozenset({'middle_arms', 'middle_legs', 'arms_2_legs_right', 'arms_2_legs_left'})
_GRAY_LINKS_HUMAN36M = frozenset({'spine'})
_GRAY = (0.75, 0.75, 0.75)
_BLACK = (0, 0, 0)
_YELLOW = (1, 1, 0)


def _parse_args():
    """Parse the command line and return the options as a dict keyed by destination name."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Subject/section folder to evaluate (e.g. S1). Used when the dataset is set to human36m or mpi",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI-INF-3DHP sequence to evaluate. Only works when the dataset is set to mpi",
                    type=str, default='Seq1')
    ap.add_argument("-si", "--show_images", help="Shows projection images", action="store_true",
                    default=False)
    ap.add_argument("-ext", "--extended_tables", help="Shows extended tables", action="store_true",
                    default=False)
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: groundtruth (default), openpose, mediapipe",
                    type=str, default='groundtruth')
    return vars(ap.parse_args())


def _load_json(path, missing_message):
    """Load and return the JSON content of *path*.

    Raises:
        Exception: with *missing_message* if *path* does not exist.
    """
    if not os.path.exists(path):
        raise Exception(missing_message)
    with open(path, "r") as fp:
        return json.load(fp)


def _lookup_keys(joint_key, use_mpi_joints, correspondence):
    """Return the (detected_key, ground_truth_key) pair under which *joint_key* is stored.

    In MPI ground-truth mode the joint is looked up under its own name in both the
    detected poses and the ground truth; otherwise the keys come from the 'body25'
    and 'ground_truth' entries of *correspondence[joint_key]*.
    """
    if use_mpi_joints:
        return joint_key, joint_key
    entry = correspondence[joint_key]
    return entry['body25'], entry['ground_truth']


def _compute_errors(poses_3d, ground_truth, joint_keys, use_mpi_joints, correspondence, divisor):
    """Return ``{frame_key: {joint_key: {'x_error', 'y_error', 'z_error', 'rmse'}}}``.

    Every coordinate is divided by *divisor* before comparison. Despite its name,
    'rmse' holds the Euclidean distance between the detected and ground-truth joint.
    """
    errors = {}
    for frame_key, frame in poses_3d.items():
        frame_errors = {}
        for joint_key in joint_keys:
            det_key, gt_key = _lookup_keys(joint_key, use_mpi_joints, correspondence)
            det = frame[det_key]
            gt = ground_truth[frame_key][gt_key]['pose']
            p_det = (det['X'] / divisor, det['Y'] / divisor, det['Z'] / divisor)
            p_gt = (gt['x'] / divisor, gt['y'] / divisor, gt['z'] / divisor)
            frame_errors[joint_key] = {
                'x_error': abs(p_det[0] - p_gt[0]),
                'y_error': abs(p_det[1] - p_gt[1]),
                'z_error': abs(p_det[2] - p_gt[2]),
                'rmse': math.dist(p_det, p_gt),
            }
        errors[frame_key] = frame_errors
    return errors


def _mean_errors(entries):
    """Average each field of _ERROR_FIELDS over the error dicts in *entries*.

    Returns a list ordered like _ERROR_FIELDS, or None when *entries* is empty
    (instead of dividing by zero).
    """
    entries = list(entries)
    if not entries:
        return None
    return [sum(entry[field] for entry in entries) / len(entries) for field in _ERROR_FIELDS]


def _format_row(first, second, values):
    """Build a table row: two label cells followed by *values* printed with 4 decimals."""
    return [first, second] + ['%.4f' % value for value in values]


def _highlight(label):
    """Wrap *label* in bright-blue colorama codes."""
    return Fore.BLUE + Style.BRIGHT + label + Style.RESET_ALL


def _print_frame_averages(errors, joint_keys):
    """Print a single table with one average-error row per frame."""
    table = PrettyTable(list(_TABLE_HEADER))
    for frame_key, frame_errors in errors.items():
        means = _mean_errors(frame_errors[joint_key] for joint_key in joint_keys)
        if means is not None:
            table.add_row(_format_row(frame_key, _highlight('Average'), means))
    print(table)


def _print_extended_tables(errors, joint_keys):
    """Print one table per frame: every joint's errors, the frame average, and the
    average over the joints not listed in _EXCLUDED_FROM_BEST."""
    for frame_key, frame_errors in errors.items():
        table = PrettyTable(list(_TABLE_HEADER))
        for joint_key in joint_keys:
            values = [frame_errors[joint_key][field] for field in _ERROR_FIELDS]
            table.add_row(_format_row(frame_key, joint_key, values))

        means = _mean_errors(frame_errors[joint_key] for joint_key in joint_keys)
        if means is not None:
            table.add_row(list(_SEPARATOR_ROW))
            table.add_row(_format_row(frame_key, _highlight('Average'), means))

        # Only the current frame contributes to its own "best joints" average.
        best_means = _mean_errors(frame_errors[joint_key] for joint_key in joint_keys
                                  if joint_key not in _EXCLUDED_FROM_BEST)
        if best_means is not None:
            table.add_row(list(_SEPARATOR_ROW))
            table.add_row(_format_row(frame_key, _highlight('Average Best Joints'), best_means))
        print(table)


def _print_joint_averages(errors, joint_keys):
    """Print a table with each joint's errors averaged over all frames."""
    table = PrettyTable(list(_TABLE_HEADER))
    for joint_key in joint_keys:
        means = _mean_errors(frame_errors[joint_key] for frame_errors in errors.values())
        if means is not None:
            table.add_row(_format_row(_highlight('Average'), joint_key, means))
    print(table)


def _link_color(link_name, human36m):
    """Return the plot colour for a skeleton link: gray for auxiliary links, black otherwise."""
    gray_links = _GRAY_LINKS_HUMAN36M if human36m else _GRAY_LINKS_DEFAULT
    return _GRAY if link_name in gray_links else _BLACK


def _show_frames(poses_3d, ground_truth, joint_keys, eval_links, use_mpi_joints, correspondence, human36m):
    """Open one 3D figure per frame showing detected vs ground-truth joints and skeletons.

    Coordinates are plotted unscaled. Pressing 'q' in a figure exits via ``press``.
    """
    for frame_key, frame in poses_3d.items():
        fig = plt.figure(frame_key)
        fig.suptitle('Frame #' + str(frame_key), fontsize=14)
        fig.canvas.mpl_connect('key_press_event', press)
        ax = plt.axes(projection='3d')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')

        xs, ys, zs, colors = [], [], [], []
        for joint_key in joint_keys:
            det_key, gt_key = _lookup_keys(joint_key, use_mpi_joints, correspondence)
            det = frame[det_key]
            gt = ground_truth[frame_key][gt_key]['pose']

            # Detected joint in its stored colour; the stored triple is unpacked as BGR
            # (OpenCV order) and reordered to RGB for matplotlib.
            xs.append(float(det['X']))
            ys.append(float(det['Y']))
            zs.append(float(det['Z']))
            b, g, r = (c / 255 for c in det['color'])
            colors.append((r, g, b))

            # Ground-truth joint in black.
            xs.append(float(gt['x']))
            ys.append(float(gt['y']))
            zs.append(float(gt['z']))
            colors.append(_BLACK)

            # Yellow segment from the ground truth to the detection.
            ax.plot([gt['x'], det['X']], [gt['y'], det['Y']], [gt['z'], det['Z']], c=_YELLOW)

        ax.scatter(xs, ys, zs, c=colors)

        for link_name, link in eval_links.items():
            color = _link_color(link_name, human36m)
            child_det_key, child_gt_key = _lookup_keys(link['child'], use_mpi_joints, correspondence)
            parent_det_key, parent_gt_key = _lookup_keys(link['parent'], use_mpi_joints, correspondence)

            child, parent = frame[child_det_key], frame[parent_det_key]
            plt.plot([child['X'], parent['X']], [child['Y'], parent['Y']], [child['Z'], parent['Z']], c=color)

            child_gt = ground_truth[frame_key][child_gt_key]['pose']
            parent_gt = ground_truth[frame_key][parent_gt_key]['pose']
            plt.plot([child_gt['x'], parent_gt['x']], [child_gt['y'], parent_gt['y']],
                     [child_gt['z'], parent_gt['z']], c=color)

        plt.show(block=False)
    # NOTE(review): this pauses for up to 1000 s before waiting for a key/button press — confirm intent.
    plt.pause(1000)
    plt.waitforbuttonpress()


def main():
    """Evaluate 3D pose estimates against a dataset's ground truth.

    Prints per-frame averages (or, with ``--extended_tables``, per-frame per-joint
    tables), then per-joint averages over all frames. With ``--show_images`` it
    also plots every frame.

    Raises:
        Exception: if the dataset folder, the ground-truth file or the 3D poses file is missing.
    """
    args = _parse_args()

    # Start from the module-level default mapping/links. Separate local names are used
    # so the imported globals are never shadowed (assigning to them here used to make
    # them unbound locals for datasets that did not override them).
    correspondence = joint_correspondence
    eval_links = links_evaluation

    # Unknown detector names fall back to 'groundtruth'.
    detector = args['2d_detector'] if args['2d_detector'] in ('mediapipe', 'openpose') else 'groundtruth'

    human36m = args['dataset_name'] == "human36m"
    mpi = args['dataset_name'] == "mpi"
    dataset_folder = '../../images/' + args['dataset_name'] + '/'
    if human36m:
        dataset_folder += 'processed/' + args['section'] + '/' + args['action'] + '/'
        if detector == "mediapipe":
            correspondence = joint_correspondence_human36m_mediapipe
            eval_links = links_evaluation_human36m_mediapipe
        else:
            correspondence = joint_correspondence_human36m
            eval_links = links_evaluation_human36m
    elif mpi:
        dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
        # NOTE(review): there is no MPI-specific correspondence table here, so non-groundtruth
        # detectors on MPI use the default mapping — confirm this is intended.
        eval_links = links_mpi_inf_3dhp

    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")
    ground_truth = _load_json(dataset_folder + "ground_truth.txt", "The ground truth file does not exist!")
    poses_file = {'openpose': "poses3d_op.json", 'mediapipe': "poses3d_mp.json"}.get(detector, "poses3d_gt.json")
    poses_3d = _load_json(dataset_folder + poses_file, "The 3D poses file does not exist!")

    # In MPI ground-truth mode, joints are compared under their own names and scaled by 1/1000
    # (presumably mm -> m, matching the "(m)" table headers — TODO confirm); otherwise the
    # correspondence table is used and coordinates are compared unscaled.
    use_mpi_joints = detector == 'groundtruth' and mpi
    if use_mpi_joints:
        joint_keys, divisor = list(joints_mpi_inf_3dhp), 1000
    else:
        joint_keys, divisor = list(correspondence), 1

    errors = _compute_errors(poses_3d, ground_truth, joint_keys, use_mpi_joints, correspondence, divisor)

    if args['extended_tables']:
        _print_extended_tables(errors, joint_keys)
    else:
        _print_frame_averages(errors, joint_keys)
    _print_joint_averages(errors, joint_keys)

    if args['show_images']:
        _show_frames(poses_3d, ground_truth, joint_keys, eval_links, use_mpi_joints, correspondence, human36m)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
