import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
import os
import json
import argparse
import math
import numpy as np
import cv2
import math
import sys
from natsort import natsorted


sys.path.insert(1, '../utils')
from links import joint_correspondence, joint_correspondence_human36m, links_evaluation, links_evaluation_human36m, \
    links_evaluation_human36m_mediapipe, joint_correspondence_human36m_mediapipe, links_mpi_inf_3dhp, \
    joints_mpi_inf_3dhp

def _load_json(path, missing_message):
    """Parse and return the JSON content of ``path``.

    Raises:
        Exception: with ``missing_message`` if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise Exception(missing_message)
    with open(path, "r") as fp:
        return json.load(fp)


def _load_frames(camera_folder, start_frame, end_frame):
    """Return RGB images whose natural-sort index lies in [start_frame, end_frame), each downsampled 4x."""
    frames = []
    for index, filename in enumerate(natsorted(os.listdir(camera_folder))):
        if index >= end_frame:
            break
        if index < start_frame:
            continue
        bgr = cv2.imread(os.path.join(camera_folder, filename))
        if bgr is None:
            # cv2.imread returns None for files it cannot decode; skip them
            # instead of crashing in cvtColor.
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (int(img.shape[1] / 4), int(img.shape[0] / 4)))
        frames.append(img)
    return frames


def _save_animation(frames, n_frames, output_path):
    """Write ``frames[0:n_frames]`` to ``output_path`` as an animation with a 100 ms frame interval."""
    fig, ax = plt.subplots()
    try:
        ax.axis('off')
        # Reuse a single image artist; calling imshow per frame would stack a
        # new artist on the axes every frame.
        image = ax.imshow(frames[0])

        def update_plot(n):
            image.set_data(frames[n])
            return (image,)

        ani = animation.FuncAnimation(fig, update_plot, frames=n_frames, interval=100, repeat=False)
        ani.save(output_path)
    finally:
        plt.close(fig)


def main():
    """Render the MPI-INF-3DHP image sequence of one or more cameras to MP4.

    Reads the command-line options and checks that the sequence folder,
    ``ground_truth.txt`` and the 3D-poses file (``--file_poses``) exist.
    For each requested camera it then:

    - loads the images of ``imageSequence/<camera>/`` whose natural-sort index
      lies in ``[--start_frame, --end_frame)``;
    - downsamples them by 4;
    - writes them to ``imageSequence/2d<camera>.mp4``.

    Raises:
        Exception: if the dataset folder, the ground-truth file, the 3D-poses
            file or a camera folder is missing, or if there is nothing to render.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="MPI-INF-3DHP subject folder to render (e.g. S1).",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI-INF-3DHP sequence to render (e.g. Seq1).",
                    type=str, default='Seq1')
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: groundtruth (default), openpose, mediapipe",
                    type=str, default='groundtruth')
    ap.add_argument("-fp", "--file_poses",
                    help="Name of file containing 3d poses",
                    type=str, default='poses3d_gt2.json')
    ap.add_argument("-cams", "--cameras", help="Choose camera detection. This camera must be present in the dataset.",
                    nargs='+', default='camera_8')
    ap.add_argument("--start_frame",
                    help="Natural-sort index of the first image to render (inclusive).",
                    type=int, default=150)
    ap.add_argument("--end_frame",
                    help="Natural-sort index at which to stop rendering (exclusive).",
                    type=int, default=300)
    args = vars(ap.parse_args())

    # --------------------------------------------------------
    # Loading files, and verifications configurations
    # --------------------------------------------------------
    sequence_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
    if not os.path.exists(sequence_folder):
        raise Exception("The dataset folder does not exist!")

    # The ground truth is loaded only to validate that the sequence is
    # complete; its content is not used below.
    _load_json(sequence_folder + "ground_truth.txt", "The ground truth file does not exist!")
    poses_3d = _load_json(sequence_folder + args['file_poses'], "The 3D poses file does not exist!")

    image_folder = sequence_folder + 'imageSequence/'

    # The default is a bare string, but a value given on the command line
    # arrives as a list because of nargs='+'; normalise to a list.
    cameras = args['cameras']
    if isinstance(cameras, str):
        cameras = [cameras]

    for camera in cameras:
        camera_folder = os.path.join(image_folder, camera)
        if not os.path.isdir(camera_folder):
            raise Exception("The camera folder does not exist!")

        frames = _load_frames(camera_folder, args['start_frame'], args['end_frame'])

        # update_plot indexes frames[n] for every n < n_frames, so never
        # request more frames than were loaded (the poses file can be longer
        # than the rendered image window).
        n_frames = min(len(poses_3d), len(frames))
        if n_frames == 0:
            raise Exception("No frames to render for " + camera + "!")

        _save_animation(frames, n_frames, os.path.join(image_folder, '2d' + camera + '.mp4'))



if __name__ == '__main__':
    # Only run the renderer when this file is executed as a script, not on import.
    main()