import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
import os
import json
import argparse
import math
import numpy as np
import cv2
import math
import sys

sys.path.insert(1, '../utils')
from links import joint_correspondence, joint_correspondence_human36m, links_evaluation, links_evaluation_human36m, \
    links_evaluation_human36m_mediapipe, joint_correspondence_human36m_mediapipe, links_mpi_inf_3dhp, \
    joints_mpi_inf_3dhp

def _load_json(path, missing_message):
    """Load and return the JSON document stored at *path*.

    Raises:
        Exception: with *missing_message* when *path* does not exist
            (a plain ``Exception``, matching the rest of this script).
    """
    if not os.path.exists(path):
        raise Exception(missing_message)
    with open(path, "r") as fp:
        return json.load(fp)


def _plot_skeleton(ax, frame, links):
    """Draw every link of one pose on *ax* and return the created artists.

    Each link is drawn as its own two-point line. Drawing all parent/child
    pairs as one polyline would also connect the child of one link to the
    parent of the next, producing spurious bones.

    Args:
        ax: the 3D axes to draw on.
        frame: mapping joint name -> dict with 'X', 'Y' and 'Z' entries.
        links: mapping link name -> dict with 'parent' and 'child' joint names.

    Returns:
        A flat list of line artists, as ``ArtistAnimation`` expects per frame.
    """
    artists = []
    for link in links.values():
        parent = frame[link['parent']]
        child = frame[link['child']]
        # ax.plot returns a list of lines; extend keeps the result flat.
        artists.extend(ax.plot([parent['X'], child['X']],
                               [parent['Y'], child['Y']],
                               [parent['Z'], child['Z']],
                               c=(0, 0, 0)))
    return artists


def main():
    """Animate the 3D poses of one MPI-INF-3DHP sequence.

    Reads ``ground_truth.txt`` and the poses file given by ``--file_poses``
    from ``../../images/mpi-inf-3dhp/<section>/<sequence>/`` and plays the
    poses as a black stick-figure animation.

    Raises:
        Exception: if the dataset folder, the ground-truth file or the poses
            file is missing, or if the poses file contains no frames.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="MPI-INF-3DHP subject folder to visualize (e.g. S1).",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI-INF-3DHP sequence to visualize (e.g. Seq1).",
                    type=str, default='Seq1')
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: groundtruth (default), openpose, mediapipe",
                    type=str, default='groundtruth')
    ap.add_argument("-fp", "--file_poses",
                    help="Name of file containing 3d poses",
                    type=str, default='poses3d_gt2.json')
    args = vars(ap.parse_args())

    # --------------------------------------------------------
    # Loading files, and verifications configurations
    # --------------------------------------------------------
    # NOTE: --dataset_name, --action and --2d_detector are parsed but not
    # used; the dataset folder is hard-coded to MPI-INF-3DHP.
    dataset_folder = os.path.join('../../images/mpi-inf-3dhp',
                                  args['section'], args['sequence'])
    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")

    # The ground truth is not drawn, but it is still required to exist and
    # be valid JSON, as before.
    _load_json(os.path.join(dataset_folder, "ground_truth.txt"),
               "The ground truth file does not exist!")
    poses_3d = _load_json(os.path.join(dataset_folder, args['file_poses']),
                          "The 3D poses file does not exist!")

    fig = plt.figure()
    # A single shared axes: creating one per frame stacked an overlapping
    # axes for every frame of the sequence.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlabel('x', fontsize=20)
    ax.set_ylabel('y', fontsize=20)
    ax.set_zlabel('z', fontsize=20)

    # One flat list of artists per frame, in file order.
    frames = [_plot_skeleton(ax, frame, links_mpi_inf_3dhp)
              for frame in poses_3d.values()]
    if not frames:
        raise Exception("The 3D poses file contains no frames!")

    # Keep a reference to the animation; otherwise it is garbage-collected
    # before plt.show() and never runs.
    ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True,
                                    repeat_delay=1000)
    plt.show()



if __name__ == '__main__':
    # Launch the viewer only when run as a script, never on import.
    main()