#!/usr/bin/env python3

import os
import json
import argparse
import math
from copy import deepcopy

import cv2
import numpy as np
import math
import sys
from prettytable import PrettyTable
from colorama import Style, Fore

sys.path.insert(1, '../utils')
from links import joint_correspondence_human36m, joint_correspondence, joint_correspondence_human36m_mediapipe, \
    links_mpi_inf_3dhp, joints_mpi_inf_3dhp, joint_correspondence_mpi_mediapipe
from projection import projectToCamera
from draw import drawSquare2D, drawCross2D


_TABLE_HEADER = ['Camera', 'Frame #', 'Joint', 'RMS (pix)', 'X err (pix)', 'Y err (pix)']
_SEPARATOR_ROW = ['--------', '--------', '-----------', '--------', '--------', '--------']

# File-name suffix of the poses3d_*.json / cameras_*.json files for each 2D detector.
_DETECTOR_SUFFIX = {'openpose': 'op', 'mediapipe': 'mp', 'groundtruth': 'gt_10px'}


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None) into a dict keyed by dest."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name: 'mpi' (MPI-INF-3DHP, default) or 'human36m' (Human 3.6M). "
                         "Datasets must be inside hpe/images/.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Section (subject) folder to evaluate. Used by the human36m and mpi datasets",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to evaluate. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence to evaluate. Only works when the dataset is set to mpi",
                    type=str, default='Seq1')
    ap.add_argument("-si", "--show_images", help="Shows projection images", action="store_true",
                    default=False)
    ap.add_argument("-ext", "--extended_tables", help="Shows extended tables", action="store_true",
                    default=False)
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. "
                         "Current options: groundtruth (default), openpose, mediapipe",
                    type=str, default='groundtruth')
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to include in the evaluation", nargs='+',
                    default=['camera_0', 'camera_4', 'camera_5', 'camera_8'])
    return vars(ap.parse_args(argv))


def _select_detector(name):
    """Normalize the --2d_detector option to 'openpose', 'mediapipe' or 'groundtruth'.

    Unrecognized names fall back to 'groundtruth', as before.
    """
    return name if name in _DETECTOR_SUFFIX else 'groundtruth'


def _dataset_folder(args):
    """Return the dataset folder path (with trailing '/') selected by the parsed arguments."""
    if args['dataset_name'] == "human36m":
        return '../../images/human36m/processed/' + args['section'] + '/' + args['action'] + '/'
    if args['dataset_name'] == "mpi":
        return '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
    return '../../images/' + args['dataset_name'] + '/'


def _load_json(path, description):
    """Load and return the JSON document at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist; *description* names the file in the message.
    """
    if not os.path.exists(path):
        raise FileNotFoundError("The " + description + " file does not exist!")
    with open(path, "r") as fp:
        return json.load(fp)


def _project_point(camera, intrinsics, distortion, xyz):
    """Project the world point *xyz* through *camera* and return its (x, y) pixel coordinates."""
    point_world = np.array([[xyz[0]], [xyz[1]], [xyz[2]], [1.0]], dtype=float)
    # Apply the extrinsics to the homogeneous point; only the first three rows feed the projection.
    point_sensor = np.dot(camera['extrinsics'], point_world)
    pixels, _, _ = projectToCamera(intrinsics, distortion, camera['width'], camera['height'],
                                   point_sensor[0:3, :])
    return pixels[0][0], pixels[1][0]


def _compute_errors(cameras, poses_3d, ground_truth, show_images=False, shrink_preview=False):
    """Reproject estimated and ground-truth joints into each camera and measure pixel errors.

    Returns ``errors[camera_key][frame_key][joint_key]``, a dict with 'x_error' and
    'y_error' (absolute per-axis pixel differences) and 'rmse' (Euclidean pixel distance).
    With *show_images*, each frame's image is annotated (ground truth as squares,
    estimates as crosses) and shown until a key is pressed; *shrink_preview* halves it.

    Raises:
        FileNotFoundError: if *show_images* is set and a frame image cannot be read.
    """
    errors = {}
    for camera_key, camera in cameras.items():
        # Convert once per camera rather than once per joint projection.
        intrinsics = np.array(camera['intrinsics'])
        distortion = np.array(camera['distortion'])
        errors[camera_key] = {}
        for frame_key, frame in camera['frames'].items():
            img = None
            if show_images:
                # Images are only needed for visualisation; skip the disk read otherwise.
                img = cv2.imread(frame['image_path'])
                if img is None:
                    raise FileNotFoundError("Could not read image " + frame['image_path'])

            frame_errors = {}
            for joint_key, joint in joints_mpi_inf_3dhp.items():
                estimate = poses_3d[frame_key][joint_key]
                reference = ground_truth[frame_key][joint_key]['pose']
                x_det, y_det = _project_point(camera, intrinsics, distortion,
                                              (estimate['X'], estimate['Y'], estimate['Z']))
                x_gt, y_gt = _project_point(camera, intrinsics, distortion,
                                            (reference['x'], reference['y'], reference['z']))
                frame_errors[joint_key] = {
                    'x_error': abs(x_det - x_gt),
                    'y_error': abs(y_det - y_gt),
                    # NOTE: per-joint Euclidean distance, not a root-mean-square; the key and
                    # the 'RMS (pix)' column name are kept for compatibility.
                    'rmse': math.dist([x_det, y_det], [x_gt, y_gt]),
                }
                if img is not None:
                    drawSquare2D(img, x_gt, y_gt, 10, color=joint['color'], thickness=3)
                    drawCross2D(img, x_det, y_det, 10, color=joint['color'], thickness=3)
            errors[camera_key][frame_key] = frame_errors

            if img is not None:
                # Show once per frame, after all joints have been drawn.
                shown = cv2.resize(img, None, fx=0.5, fy=0.5) if shrink_preview else img
                cv2.imshow(camera_key + '_' + frame_key, shown)
                cv2.waitKey(0)
    return errors


def _average_label():
    """Return the highlighted 'Average' cell text used in summary rows."""
    # RESET_ALL restores the terminal defaults; the former Fore.BLACK left the rest of
    # the table black, which is unreadable on dark terminals.
    return Fore.BLUE + Style.BRIGHT + 'Average' + Style.RESET_ALL


def _summary_row(labels, records):
    """Return *labels* followed by the mean 'rmse', 'x_error', 'y_error' of *records* as '%.3f'.

    An empty *records* yields 'nan' cells instead of raising ZeroDivisionError.
    """
    count = len(records)
    row = list(labels)
    for key in ('rmse', 'x_error', 'y_error'):
        mean = sum(record[key] for record in records) / count if count else float('nan')
        row.append('%.3f' % mean)
    return row


def _valid_records(frame, frame_errors):
    """Return (joint_key, error record) pairs for joints flagged valid in *frame*, in joint order."""
    return [(joint_key, frame_errors[joint_key])
            for joint_key in joints_mpi_inf_3dhp
            if frame['joints'][joint_key]['valid']]


def _add_frame_rows(table, camera_key, camera, camera_errors, extended):
    """Append one average row per frame (plus per-joint rows when *extended*) to *table*."""
    for frame_key, frame in camera['frames'].items():
        records = _valid_records(frame, camera_errors[frame_key])
        if extended:
            for joint_key, record in records:
                table.add_row(_summary_row([camera_key, frame_key, joint_key], [record]))
            table.add_row(_SEPARATOR_ROW)
        table.add_row(_summary_row([camera_key, frame_key, _average_label()],
                                   [record for _, record in records]))
        if extended:
            table.add_row(_SEPARATOR_ROW)


def _print_frame_tables(cameras, errors, extended):
    """Print per-frame averages: one shared table, or one detailed table per camera if *extended*."""
    if extended:
        for camera_key, camera in cameras.items():
            table = PrettyTable(_TABLE_HEADER)
            _add_frame_rows(table, camera_key, camera, errors[camera_key], extended=True)
            print(table)
    else:
        table = PrettyTable(_TABLE_HEADER)
        for camera_key, camera in cameras.items():
            _add_frame_rows(table, camera_key, camera, errors[camera_key], extended=False)
        print(table)


def _print_joint_table(cameras, errors):
    """Print, per camera and joint, the error averaged over the frames where the joint is valid."""
    table = PrettyTable(_TABLE_HEADER)
    for camera_key, camera in cameras.items():
        for joint_key in joints_mpi_inf_3dhp:
            records = [errors[camera_key][frame_key][joint_key]
                       for frame_key, frame in camera['frames'].items()
                       if frame['joints'][joint_key]['valid']]
            table.add_row(_summary_row([camera_key, _average_label(), joint_key], records))
    print(table)


def _print_camera_table(cameras, errors):
    """Print, per camera, the error averaged over every valid (frame, joint) pair."""
    table = PrettyTable(_TABLE_HEADER)
    for camera_key, camera in cameras.items():
        # The count restarts for each camera (it previously leaked across cameras/tables).
        records = [record
                   for frame_key, frame in camera['frames'].items()
                   for _, record in _valid_records(frame, errors[camera_key][frame_key])]
        table.add_row(_summary_row([camera_key, _average_label(), _average_label()], records))
    print(table)


def main(argv=None):
    """Report 2D reprojection errors of estimated 3D poses against the ground truth.

    Loads the ground truth, the estimated 3D poses and the camera parameters of the
    selected dataset/detector. It projects both skeletons into each selected camera and
    prints per-frame, per-joint and per-camera pixel-error tables.

    Args:
        argv: command-line arguments without the program name; ``None`` reads ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: if the dataset folder or one of its files is missing, or (with
            --show_images) a frame image cannot be read.
    """
    args = _parse_args(argv)
    detector = _select_detector(args['2d_detector'])

    dataset_folder = _dataset_folder(args)
    if not os.path.exists(dataset_folder):
        raise FileNotFoundError("The dataset folder does not exist!")

    # NOTE(review): a dataset/detector-specific joint_correspondence map used to be
    # selected here but was never read; errors are always computed over
    # joints_mpi_inf_3dhp. Confirm whether Human 3.6M needs its own joint set.
    suffix = _DETECTOR_SUFFIX[detector]
    ground_truth = _load_json(dataset_folder + "ground_truth.txt", "ground truth")
    poses_3d = _load_json(dataset_folder + "poses3d_" + suffix + ".json", "3D poses")
    cameras = _load_json(dataset_folder + "cameras_" + suffix + ".json", "camera dictionary")

    # Errors are only computed for the requested cameras, so every table must be
    # restricted to them as well (other cameras previously raised KeyError).
    selected = {key: camera for key, camera in cameras.items() if key in args['cameras_to_use']}

    errors = _compute_errors(selected, poses_3d, ground_truth,
                             show_images=args['show_images'],
                             shrink_preview=(detector == 'mediapipe'))

    _print_frame_tables(selected, errors, args['extended_tables'])
    _print_joint_table(selected, errors)
    _print_camera_table(selected, errors)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    main()
