import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
import os
import json
import argparse
import math
import numpy as np
import cv2
import math
import sys

sys.path.insert(1, '../utils')
from links import joint_correspondence, joint_correspondence_human36m, links_evaluation, links_evaluation_human36m, \
    links_evaluation_human36m_mediapipe, joint_correspondence_human36m_mediapipe, links_mpi_inf_3dhp, \
    joints_mpi_inf_3dhp

def _load_json(path, missing_message):
    """Load and return the JSON document stored at ``path``.

    :param path: file to read.
    :param missing_message: message of the exception raised when ``path``
        does not exist.
    :return: the decoded JSON value.
    :raises Exception: if ``path`` does not exist (chained to the underlying
        ``FileNotFoundError``).
    """
    try:
        with open(path, "r") as fp:
            return json.load(fp)
    except FileNotFoundError as err:
        raise Exception(missing_message) from err


def _bgr_to_rgb(color):
    """Turn a stored ``(b, g, r)`` colour into a matplotlib ``(r, g, b)`` tuple.

    Each channel is divided by 255 (assumes the stored values are 0-255).
    Raises ``ValueError`` if ``color`` does not have exactly three channels.
    """
    b, g, r = (c / 255 for c in color)
    return r, g, b


def _collect_frames(poses_3d, links):
    """Convert the per-frame pose dictionary into plot-ready frame records.

    :param poses_3d: mapping frame key -> {joint key -> joint}, where each joint
        has ``'X'``, ``'Y'``, ``'Z'`` and ``'color'`` entries.
    :param links: mapping link name -> ``{'parent': joint key, 'child': joint key}``.
    :return: one dict per frame, in the iteration order of ``poses_3d``, with
        keys ``'X'``, ``'Y'``, ``'Z'`` (joint coordinates), ``'colors'`` (RGB
        tuples, see ``_bgr_to_rgb``) and ``'segments'`` (one
        ``([X0, X1], [Y0, Y1], [Z0, Z1])`` triple per link, every link included).
    """
    frame_data = []
    for frame in poses_3d.values():
        joints = list(frame.values())
        segments = []
        for link in links.values():
            parent = frame[link['parent']]
            child = frame[link['child']]
            segments.append(([parent['X'], child['X']],
                             [parent['Y'], child['Y']],
                             [parent['Z'], child['Z']]))
        frame_data.append({
            'X': [joint['X'] for joint in joints],
            'Y': [joint['Y'] for joint in joints],
            'Z': [joint['Z'] for joint in joints],
            'colors': [_bgr_to_rgb(joint['color']) for joint in joints],
            'segments': segments,
        })
    return frame_data


def main():
    """Render the 3D poses of one MPI-INF-3DHP sequence to ``3d_2.mp4``.

    Inputs are read from ``../../images/mpi-inf-3dhp/<section>/<sequence>/``:
    the poses file named by ``--file_poses`` and ``ground_truth.txt``. The
    animation has one frame per entry of the poses file and is saved into the
    same folder.

    :raises Exception: if the sequence folder, the ground-truth file or the
        poses file is missing, or if the poses file contains no frames.
    """
    ap = argparse.ArgumentParser()
    # NOTE: --dataset_name, --action and --2d_detector are parsed for CLI
    # compatibility but are not used by this script.
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="MPI-INF-3DHP subject folder (e.g. S1) containing the sequence to visualize",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI-INF-3DHP sequence to visualize (e.g. Seq1)",
                    type=str, default='Seq1')
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: groundtruth (default), openpose, mediapipe",
                    type=str, default='groundtruth')
    ap.add_argument("-fp", "--file_poses",
                    help="Name of file containing 3d poses",
                    type=str, default='poses3d_gt2.json')
    args = vars(ap.parse_args())

    # --------------------------------------------------------
    # Loading files and verifying the configuration
    # --------------------------------------------------------
    dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")

    # The ground-truth overlay is currently disabled, so this file is only
    # required to exist and to be valid JSON; its content is not plotted.
    _load_json(os.path.join(dataset_folder, "ground_truth.txt"),
               "The ground truth file does not exist!")
    poses_3d = _load_json(os.path.join(dataset_folder, args['file_poses']),
                          "The 3D poses file does not exist!")

    frame_data = _collect_frames(poses_3d, links_mpi_inf_3dhp)
    if not frame_data:
        raise Exception("The 3D poses file does not contain any frame!")

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1, projection='3d')
    ax1.set_title("3D pose")
    ax1.set_zlim3d(0, 5000)
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    ax1.view_init(azim=90, elev=-50)

    def update_plot(i):
        """Redraw frame ``i`` from scratch on ``ax1`` with fixed axis limits."""
        ax1.clear()
        ax1.set_xlim3d(-1000, 2000)
        ax1.set_ylim3d(0, 2000)
        ax1.set_zlim3d(-1000, 1000)
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax1.get_yticklabels(), visible=False)
        plt.setp(ax1.get_zticklabels(), visible=False)

        frame = frame_data[i]
        ax1.scatter(frame['X'], frame['Y'], frame['Z'], c=frame['colors'])
        # Draw every link (the previous loop stopped one short and dropped the last link).
        for xs, ys, zs in frame['segments']:
            ax1.plot(xs, ys, zs, c=(0.5, 0.5, 0.5))

    ani = animation.FuncAnimation(fig, update_plot, frames=len(frame_data), interval=100, repeat=False)
    ani.save(os.path.join(dataset_folder, '3d_2.mp4'), bitrate=1000)



# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()