#!/usr/bin/env python3

import argparse
import math
import os
import json
import sys
from functools import partial
from copy import deepcopy
from pathlib import Path
import h5py
import imageio

import matplotlib.animation as animation
import matplotlib.ticker as plticker
import cv2
import numpy as np
import matplotlib.pyplot as plt
import yaml
from yaml import SafeLoader
from natsort import natsorted

if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager

from mpl_toolkits.mplot3d import Axes3D

sys.path.append(str(Path("../").resolve()))  # add previous folder to python path

from utils.transforms import getTransform, get_transform_tree_dict
from utils.draw import drawSquare2D, drawCross2D, drawCircle, drawDiagonalCross2D, draw3Dcoordinatesystem
from utils.links import links_ground_truth, joint_idxs, intial_estimate_joints, joints_human36m, \
    joints_mediapipe, links_human36m, links_mediapipe, links_openpose, joint_idxs_op_mpi, links_op_mpi, \
    links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from utils.projection import projectToCamera
from utils.load_cam_params import getCamParamsHuman36m, getCamParamsMPI
from objective_function import objectiveFunction
from numba import jit, cuda


# @jit
# @cuda.jit
def main():
    # --------------------------------------------------------
    # Arguments
    # --------------------------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-sff", "--skip_frame_to_frame_residuals", help="Use frame to frame residuals", action="store_true",
                    default=False)
    ap.add_argument("-sll", "--skip_link_length_residuals", help="Use link length residuals", action="store_true",
                    default=False)
    ap.add_argument("-gt", "--has_ground_truth", help="Use groundtruth in visualization", action="store_true",
                    default=False)
    ap.add_argument("-db", "--debug", help="Prints debug lines", action="store_true",
                    default=False)
    ap.add_argument("-si", "--show_images", help="Show optimization images", action="store_true",
                    default=False)
    ap.add_argument("-pe", "--phased_execution", help="After optimization finished keep windows running",
                    action="store_true",
                    default=False)
    ap.add_argument("-o", "--save_output", help="Save output json file with final skeleton and poses",
                    action="store_true",
                    default=False)
    ap.add_argument("-v", "--video", help="Saves .mp4 video in the dataset folder",
                    action="store_true",
                    default=False)
    ap.add_argument("-all", "--optimize_all",
                    help=" Use to optimize the entire dataset/video. If this argument is set, the program ignores the -frames flag.",
                    action="store_true",
                    default=False)
    ap.add_argument("-max", "--max_frames", help="Maximum number of frames when calibrating the entire video.",
                    type=int)
    ap.add_argument("-sf", "--start_frame", help="Frame to start optimization.", type=int, default=0)
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: openpose (default), mediapipe, groundtruth",
                    type=str, default='openpose')
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M. Please use -dataset mpi to optimize Mpi Inf 3DHP",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Human 3.6M section to optimize. Only works when the dataset is set to human36m",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence optimize. Only works when the dataset is set to human36m",
                    type=str, default='Seq1')
    ap.add_argument("-xacro", "--xacro_name", help="Name of xacro file. Must be inside folder xacros/.", type=str,
                    default="novo.urdf.xacro")
    ap.add_argument("-2d_poses", "--2d_poses",
                    help="Specify the name of the .npy file containing the 2d poses. Must be inside dataset_name/cache/camera_x/filename.npy",
                    type=str)
    ap.add_argument("-frames", "--frames_to_use", help="Frames to use in optimization. ", nargs='+', type=int,
                    default=[1, 2, 3])
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to use in optimization", nargs='+',
                    default=['camera_0', 'camera_4', 'camera_5', 'camera_8'])
    # default = ['54138969', '55011271', '58860488', '60457274']

    args = vars(ap.parse_args())
    # --------------------------------------------------------
    # Configuration and verifications
    # --------------------------------------------------------
    world_frame = 'world'

    dataset_folder = '../../images/' + args['dataset_name'] + '/'

    human36m = False
    mpi = False
    if args['dataset_name'] == "human36m":
        human36m = True
        dataset_folder = '../../images/human36m/processed/' + args['section'] + '/' + args['action'] + '/'
        # print(dataset_folder)
    elif args['dataset_name'] == "mpi":
        mpi = True
        dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'

    detector = 'groundtruth'
    if args['2d_detector'] == "mediapipe":
        detector = 'mediapipe'
        links = links_mediapipe
        skeleton_joints = joints_mediapipe
    elif args['2d_detector'] == "openpose":
        detector = 'openpose'
        links = links_openpose
        skeleton_joints = joint_idxs
        if mpi == True:
            skeleton_joints = joint_idxs_op_mpi
            links = links_op_mpi
    else:
        links = links_mpi_inf_3dhp
        skeleton_joints = joints_mpi_inf_3dhp

    print("Running with detector " + detector + " and dataset " + args['dataset_name'])

    # check if dataset folder exists:
    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")
    # dataset_folder = dataset_folder

    if human36m == False and mpi == False:
        xacro_name = args['xacro_name']
        file_xacro = "../xacros/" + xacro_name

        if not os.path.exists(file_xacro):
            raise Exception(
                "The xacro file does not exist! The xacro file must be inside the hpe/scripts/xacros folder.")
        transform_dict = get_transform_tree_dict(file_xacro)
    else:
        if human36m == True:
            cam_params = getCamParamsHuman36m([args['section']], args['cameras_to_use'], [args['action']])
        if mpi == True:
            cam_params = getCamParamsMPI(args['section'], args['sequence'], args['cameras_to_use'])

    # Define cameras to use
    cameras = {}
    for camera in args["cameras_to_use"]:
        # check if sensor exists:
        if human36m == False and mpi == False:
            if os.path.exists(dataset_folder + camera):
                cameras[camera] = {}
            else:
                raise Exception("The given sensor " + camera + " does not exist in the dataset!")
        else:
            if os.path.exists(dataset_folder + 'imageSequence/' + camera):
                cameras[camera] = {}
            else:
                raise Exception("The given sensor " + camera + " does not exist in the dataset!")

    idxs_to_use = []
    # Define frames to use
    # if not args['optimize_all']:
    #     idxs_to_use = ['img_000001.jpg']
    # else:
    #
    # exit(0)

    joints_to_use = []
    for _, link in links.items():  # derive al joints to use based on the parent and childs
        joints_to_use.append(link['parent'])
        joints_to_use.append(link['child'])

    joints_to_use = list(set(joints_to_use))  # remove repetitions

    # Paths to files
    first_camera = True
    frames_total = 10000000000
    for camera_key, camera in cameras.items():
        if human36m == False and mpi == False:
            camera['folder'] = dataset_folder + camera_key + '/'
            camera['folder_skeleton'] = dataset_folder + 'cache/' + camera_key + '/'
        else:
            camera['folder'] = dataset_folder + 'imageSequence/' + camera_key + '/'
            camera['folder_skeleton'] = dataset_folder + 'cache/' + camera_key + '/'
        camera['files'] = natsorted(os.listdir(camera['folder']))
        if args['2d_poses'] and os.path.exists(camera['folder_skeleton'] + args['2d_poses']):
            camera['file_skeleton'] = camera['folder_skeleton'] + args['2d_poses']
        elif detector == 'openpose':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_op.npy'
        elif detector == 'mediapipe':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_mediapipe.npy'
        elif detector == 'groundtruth':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_gt.npy'
        else:
            raise Exception("2D detector not available!")

        n_frames_camera = len(camera['files'])
        if n_frames_camera < frames_total:
            frames_total = n_frames_camera

        if not os.path.exists(camera['file_skeleton']):
            raise Exception(
                "The 2D poses file does not exist! The file must be inside the images/dataset_name/cache/camera_x/ folder.")

        camera['frame_id'] = camera_key + '_rgb_optical_frame'

        start_frame = args['start_frame']

        if first_camera:
            if args['optimize_all']:
                idxs_to_use = camera['files']
            elif args['max_frames']:
                max = int(args['max_frames'])
                idxs_to_use = camera['files'][start_frame:start_frame + max]
                # print(idxs_to_use)
            elif args['frames_to_use']:
                for idx in args['frames_to_use']:
                    idxs_to_use.append(camera['files'][idx])

        # exit(0)

        # Check if frame exists
        for idx in idxs_to_use:
            # print(camera['folder'] + str(idx) + '.png')
            if human36m == False and mpi == False:
                if not os.path.exists(camera['folder'] + str(idx) + '.png'):
                    raise Exception("The frame " + str(idx) + " does not exist in the dataset!")
            else:
                if not os.path.exists(camera['folder'] + str(idx)):
                    raise Exception("The frame " + str(idx) + " does not exist in the dataset!")

        first_camera = False
        # TODO change to frame_id

    # Read camera intrinsics
    for camera_key, camera in cameras.items():
        if human36m == False and mpi == False:
            yaml_folder = dataset_folder + camera_key + '_sim.yaml'
            with open(yaml_folder) as f:
                data = yaml.load(f, Loader=SafeLoader)
            camera['intrinsics'] = np.reshape(data['camera_matrix']['data'], (3, 3))
            camera['distortion'] = np.reshape(data['distortion_coefficients']['data'], (5, 1))
            camera['width'] = data['image_width']
            camera['height'] = data['image_height']
        else:
            camera['intrinsics'] = cam_params[camera_key]['K']
            camera['distortion'] = cam_params[camera_key]['dist']
            if human36m == True:
                if camera_key == '54138969' or camera_key == '60457274':
                    camera['width'] = 1000
                    camera['height'] = 1002
                else:
                    camera['width'] = 1000
                    camera['height'] = 1000
            if mpi == True:
                camera['width'] = cam_params[camera_key]['width']
                camera['height'] = cam_params[camera_key]['height']

    # Read camera extrinsics
    for camera_key, camera in cameras.items():
        if human36m == False and mpi == False:
            camera['extrinsics'] = getTransform(camera['frame_id'], world_frame, transform_dict)
        else:
            camera['extrinsics'] = cam_params[camera_key]['T']

    # Read images and files
    for _, camera in cameras.items():
        poses_np = np.load(camera['file_skeleton'], allow_pickle=True)
        camera['frames'] = {}
        # print(zip(idxs_to_use, poses_np.tolist()[:len(idxs_to_use)]))
        # for i in zip(idxs_to_use, poses_np.tolist()[:len(idxs_to_use)]):
        #     print(i)
        for idx, (file, pose_np) in enumerate(
                zip(idxs_to_use, poses_np.tolist()[start_frame:start_frame + len(idxs_to_use)])):
            # print(idx)
            idx = idx + start_frame
            # print(idx)
            frame = cv2.imread(camera['folder'] + file)

            joints = {}
            for joint_idx_key, joint_data_value in skeleton_joints.items():
                joint_idx_value = joint_data_value['idx']
                if joint_idx_key not in joints_to_use:
                    continue

                point = pose_np[joint_idx_value]
                x, y = point[0], point[1]

                if len(point) == 3:
                    confidence = point[2]
                    valid = (not x == 0) and (not y == 0)
                    if valid == False:
                        color = (0, 0, 255)
                    else:
                        color = joint_data_value['color']
                    joints[joint_idx_key] = {'x': x, 'y': y, 'confidence': confidence, 'valid': valid,
                                             'x_proj': 0.0, 'y_proj': 0.0, 'color': color}
                else:
                    valid = (not x == 0) and (not y == 0)
                    if valid == False:
                        color = (0, 0, 255)
                    else:
                        color = joint_data_value['color']
                    joints[joint_idx_key] = {'x': x, 'y': y, 'valid': valid,
                                             'x_proj': 0.0, 'y_proj': 0.0, 'color': color}
                # TODO implement a better initialization

            camera['frames'][str(idx)] = {'image_path': camera['folder'] + file, 'image': frame, 'joints': joints}

    # exit(0)

    selected_camera_key = list(cameras.keys())[0]  # select the first in the list arbitrarily
    selected_camera = cameras[selected_camera_key]

    # read ground truth
    if args['has_ground_truth']:
        # if human36m == False:
        ground_truth_path = dataset_folder + "ground_truth.txt"
        if os.path.exists(ground_truth_path):
            with open(ground_truth_path, "r") as fp:
                ground_truth = json.load(fp)
        else:
            raise Exception("The ground truth file does not exist! Run without -gt flag.")

    # create the poses 3D dictionary.
    # NOTE must have the same size as the frames in the sensors

    # Since the frame lists of all cameras have the same size, we can use the one from the first camera.
    poses3d = {}
    for frame_key, frame in selected_camera['frames'].items():
        if int(frame_key) > frames_total:
            continue
        poses3d[frame_key] = {}
        for joint_key, joint in frame['joints'].items():
            poses3d[frame_key][joint_key] = {'X': 0.0, 'Y': 0.0, 'Z': 0.0, 'color': joint['color']}

    # -----------------------------------------
    # OPTIMIZATION
    # -----------------------------------------

    print(idxs_to_use)
    # Configure optimizer
    opt = Optimizer()
    opt.addDataModel('args', args)
    opt.addDataModel('cameras', cameras)
    opt.addDataModel('poses3d', poses3d)
    opt.addDataModel('groundtruth', ground_truth)
    opt.addDataModel('frame_list', idxs_to_use)

    errors_per_iteration = {}
    opt.addDataModel('errors', errors_per_iteration)

    # opt.getNumberOfFunctionCallsPerIteration(optimization_options):

    # opt.addDataModel('iteration', iteration)

    # ----------------------------------------------
    # Setters and getters
    # ----------------------------------------------
    def getJointXYZ(data, frame_key, joint_key):
        """Return the 3D coordinates of one joint as a list ``[X, Y, Z]``.

        ``data`` is indexed first by ``frame_key`` and then by ``joint_key``;
        the resulting entry must provide the keys 'X', 'Y' and 'Z'.
        """
        joint_entry = data[frame_key][joint_key]
        # Axis order is fixed so the result lines up with what the setter expects.
        return [joint_entry[axis] for axis in ('X', 'Y', 'Z')]

    def setJointXYZ(data, values, frame_key, joint_key):
        """Write the first three entries of ``values`` into one joint's X, Y, Z.

        The joint entry is ``data[frame_key][joint_key]`` and is modified in
        place; nothing is returned. Entries of ``values`` beyond index 2 are
        ignored.
        """
        # Read all three coordinates before touching the target, so a short
        # ``values`` fails before anything is written.
        coordinates = (values[0], values[1], values[2])
        joint_entry = data[frame_key][joint_key]
        for axis, coordinate in zip(('X', 'Y', 'Z'), coordinates):
            joint_entry[axis] = coordinate

    # Test getters and setters
    # setJointXYZ(cameras, [33,44,55], 'camera_2', '200', 'RWrist')
    # pt = getJointXYZ(cameras, 'camera_2', '200', 'RShoulder')
    # print(pt)
    # pt = getJointXYZ(cameras, 'camera_3', '200', 'RShoulder')
    # print(pt)

    # ----------------------------------------------
    # Create the parameters
    # ----------------------------------------------
    for frame_key, frame in selected_camera['frames'].items():
        for joint_key, joint in frame['joints'].items():
            group_name = 'frame_' + frame_key + '_joint_' + joint_key
            opt.pushParamVector(group_name, data_key='poses3d',
                                getter=partial(getJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                setter=partial(setJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                suffix=['_X', '_Y', '_Z'])

    # opt.printParameters()
    # ----------------------------------------------
    # Define the objective function
    # ----------------------------------------------
    opt.setObjectiveFunction(objectiveFunction)
    # ----------------------------------------------
    # Define the residuals
    # ----------------------------------------------

    # Projection residuals
    for camera_key, camera in cameras.items():
        for frame_key, frame in camera['frames'].items():
            for joint_key, joint in frame['joints'].items():
                if not joint['valid']:  # skip invalid joints
                    continue

                parameter_pattern = 'frame_' + frame_key + '_joint_' + joint_key
                residual_key = 'projection_sensor_' + camera_key + '_frame_' + frame_key + '_joint_' + joint_key

                params = opt.getParamsContainingPattern(pattern=parameter_pattern)  # get all weight related parameters
                opt.pushResidual(name=residual_key, params=params)

    # Frame distance residuals
    if not args['skip_frame_to_frame_residuals']:
        for initial_frame_key, final_frame_key in zip(list(poses3d.keys())[:-1], list(poses3d.keys())[1:]):
            if not int(final_frame_key) - int(initial_frame_key) == 1:  # frames are not consecutive
                # print('Frame ' + initial_frame_key + ' and ' + final_frame_key + ' are not consecutive.')
                continue

            initial_pose = poses3d[initial_frame_key]

            for joint_key in initial_pose.keys():
                residual_key = 'consecutive_frame_' + initial_frame_key + '_frame_' + final_frame_key + '_joint_' + joint_key

                params = opt.getParamsContainingPattern(
                    pattern='frame_' + initial_frame_key + '_joint_' + joint_key + '_')  # get parameters for frame initial_frame
                params.extend(opt.getParamsContainingPattern(
                    pattern='frame_' + final_frame_key + '_joint_' + joint_key + '_'))  # get parameters for frame final_frame
                opt.pushResidual(name=residual_key, params=params)

    # Link length residuals
    if not args['skip_link_length_residuals']:
        for link_key, link in links.items():
            # print('Link ' + link['parent'] + '-' + link['child'])

            # for frame_key, frame in poses3d.items():  # compute residual as distance from reference link length

            params = opt.getParamsContainingPattern(pattern='joint_' + link['parent'] + '_')
            params.extend(
                opt.getParamsContainingPattern(pattern='joint_' + link['child'] + '_'))

            residual_key = 'link_length_' + 'joint_' + link['parent'] + '_joint_' + link['child']
            opt.pushResidual(name=residual_key, params=params)

    opt.printResiduals()

    # ----------------------------------------------
    # Compute sparse matrix
    # ----------------------------------------------
    opt.computeSparseMatrix()
    opt.printSparseMatrix()

    # ----------------------------------------------
    # Visualization function
    # ----------------------------------------------

    # if args['optimize_all'] == True:
    #     args['show_images'] = False

    # Setup visualization
    if args['show_images']:
        # Draw links in images
        for _, camera in cameras.items():
            for frame_idx, (_, frame) in enumerate(camera['frames'].items()):
                frame['image_gui'] = deepcopy(frame['image'])
                frame.items()
                for link_name, link in links.items():
                    joint0 = frame['joints'][link['parent']]
                    joint1 = frame['joints'][link['child']]

                    if not joint0['valid'] or not joint1['valid']:
                        continue

                    x0, y0 = int(joint0['x']), int(joint0['y'])
                    x1, y1 = int(joint1['x']), int(joint1['y'])

                    cv2.line(frame['image_gui'], (x0, y0), (x1, y1), (128, 128, 128), 3)

        # Draw 2D detections on images
        for camera_key, camera in cameras.items():
            for _, frame in camera['frames'].items():
                for joint_idx, joint in frame['joints'].items():
                    x, y = int(joint['x']), int(joint['y'])
                    color = joint['color']
                    if not joint['valid']:
                        print(joint_idx + " is not valid in camera " + camera_key)
                        color = (255, 0, 0)
                        # continue

                    square_size = 5 + (20 - 5) * joint['confidence']
                    cv2.putText(frame['image_gui'], joint_idx, (x, y), cv2.FONT_HERSHEY_PLAIN, 0.5, (0, 255, 0))
                    drawSquare2D(frame['image_gui'], x, y, square_size, color=color, thickness=3)

        # Draw projected groundtruth in images
        if args['has_ground_truth']:
            for _, camera in cameras.items():
                for frame_key, frame in camera['frames'].items():
                    for joint_key, joint in ground_truth[frame_key].items():

                        pts_in_world = np.ndarray((4, 1), dtype=float)
                        pts_in_world[0][0] = ground_truth[frame_key][joint_key]['pose']['x']
                        pts_in_world[1][0] = ground_truth[frame_key][joint_key]['pose']['y']
                        pts_in_world[2][0] = ground_truth[frame_key][joint_key]['pose']['z']
                        pts_in_world[3][0] = 1

                        pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)

                        # Project 3D point to camera
                        if mpi == True:
                            pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                                 None,
                                                                 camera['width'],
                                                                 camera['height'], pts_in_sensor[0:3, :])
                        else:
                            pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                                 np.array(camera['distortion']),
                                                                 camera['width'],
                                                                 camera['height'], pts_in_sensor[0:3, :])

                        xpix_projected = pts_in_image[0][0]
                        ypix_projected = pts_in_image[1][0]

                        x, y = int(xpix_projected), int(ypix_projected)
                        color = (0, 0, 0)

                        drawDiagonalCross2D(frame['image_gui'], x, y, 10, color=color, thickness=3)

        # Create visualization in images
        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                window_name = camera_key + '_' + frame_key
                # cv2.namedWindow(window_name, cv2.WINDOW_FULLSCREEN)
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                cv2.imshow(window_name, frame['image_gui'])

        # Setup 3D visualization
        plot_handles = {}
        for frame_key, frame in selected_camera['frames'].items():
            plot_handles[frame_key] = {}
            plot_handles[frame_key]['figure'] = plt.figure()
            plot_handles[frame_key]['figure'].suptitle('Frame #' + str(frame_key), fontsize=14)
            plot_handles[frame_key]['axes'] = plot_handles[frame_key]['figure'].add_subplot(111, projection='3d')
            plot_handles[frame_key]['axes'].set_xlim3d(-1, 1)
            plot_handles[frame_key]['axes'].set_ylim3d(-1, 1)
            plot_handles[frame_key]['axes'].set_zlim3d(0, 2)
            plot_handles[frame_key]['axes'].set_xlabel('x', fontsize=20)
            plot_handles[frame_key]['axes'].set_ylabel('y', fontsize=20)
            plot_handles[frame_key]['axes'].set_zlabel('z', fontsize=20)
            # loc = plticker.MultipleLocator(base=1)  # this locator puts ticks at regular intervals
            # plot_handles[frame_key]['axes'].xaxis.set_major_locator(loc)
            # plot_handles[frame_key]['axes'].yaxis.set_major_locator(loc)
            # plot_handles[frame_key]['axes'].zaxis.set_major_locator(loc)
            plt.setp(plot_handles[frame_key]['axes'].get_xticklabels(), visible=False)
            plt.setp(plot_handles[frame_key]['axes'].get_yticklabels(), visible=False)
            plt.setp(plot_handles[frame_key]['axes'].get_zticklabels(), visible=False)

        figures = []
        for plot_handle_key, plot_handle in plot_handles.items():
            figures.append(plot_handle['figure'])

        wm = WindowManager(figs=figures)
        opt.addDataModel('window_manager', wm)

        # Draw floor coordinate system
        for frame_key, frame in selected_camera['frames'].items():
            X_vec = []
            Y_vec = []
            Z_vec = []
            joint_colors = []

            min_z = 1000
            for joint_key, joint in frame['joints'].items():
                point = poses3d[frame_key][joint_key]
                X_vec.append(point['X'])
                Y_vec.append(point['Y'])
                Z_vec.append(point['Z'])
                b, g, r = tuple(c / 255 for c in point['color'])
                joint_colors.append((r, g, b))

                if point['Z'] < min_z:
                    min_z = point['Z']

            # ground truth
            if args['has_ground_truth']:
                for joint_key, joint in ground_truth[frame_key].items():
                    point = ground_truth[frame_key][joint_key]['pose']
                    X_vec.append(float(point['x']))
                    Y_vec.append(float(point['y']))
                    Z_vec.append(float(point['z']))
                    b, g, r = (0, 0, 0)
                    joint_colors.append((r, g, b))

            plot_handles[frame_key]['joint_handle'] = plot_handles[frame_key]['axes'].scatter(X_vec, Y_vec, Z_vec,
                                                                                              c=joint_colors)  # Draw N points
            plot_handles[frame_key]['coordinate_system'] = draw3Dcoordinatesystem(plot_handles[frame_key]['axes'], [],
                                                                                  xc=0, yc=0, zc=min_z, size=0.2)

        for frame_key, frame in selected_camera['frames'].items():
            # Draw ground_truth in 3D
            if args['has_ground_truth']:
                plot_handles[frame_key]['ground_truth'] = {}
                if human36m:
                    for link_name, link in links_human36m.items():
                        plot_handles[frame_key]['ground_truth'][link_name] = {}
                        joint0 = ground_truth[frame_key][link['parent']]['pose']
                        joint1 = ground_truth[frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']

                        plot_handles[frame_key]['ground_truth'][link_name]['link_handle_gt'] = plot_handles[frame_key][
                            'axes'].plot([X0, X1], [Y0, Y1],
                                         [Z0, Z1],
                                         c=(0, 0, 0))
                elif mpi:
                    for link_name, link in links_mpi_inf_3dhp.items():
                        plot_handles[frame_key]['ground_truth'][link_name] = {}
                        joint0 = ground_truth[frame_key][link['parent']]['pose']
                        joint1 = ground_truth[frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']

                        plot_handles[frame_key]['ground_truth'][link_name]['link_handle_gt'] = plot_handles[frame_key][
                            'axes'].plot([X0, X1], [Y0, Y1],
                                         [Z0, Z1],
                                         c=(0, 0, 0))
                else:
                    for link_name, link in links_ground_truth.items():
                        plot_handles[frame_key]['ground_truth'][link_name] = {}
                        joint0 = ground_truth[frame_key][link['parent']]['pose']
                        joint1 = ground_truth[frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']

                        plot_handles[frame_key]['ground_truth'][link_name]['link_handle_gt'] = \
                            plot_handles[frame_key][
                                'axes'].plot([X0, X1], [Y0, Y1],
                                             [Z0, Z1],
                                             c=(0, 0, 0))

            # Plot links in 3D
            for link_name, link in links.items():
                plot_handles[frame_key][link_name] = {}
                joint0 = poses3d[frame_key][link['parent']]
                joint1 = poses3d[frame_key][link['child']]
                X0 = joint0['X']
                Y0 = joint0['Y']
                Z0 = joint0['Z']
                X1 = joint1['X']
                Y1 = joint1['Y']
                Z1 = joint1['Z']

                plot_handles[frame_key][link_name]['link_handle'] = plot_handles[frame_key]['axes'].plot([X0, X1],
                                                                                                         [Y0, Y1],
                                                                                                         [Z0, Z1],
                                                                                                         c=(
                                                                                                             0.5, 0.5,
                                                                                                             0.5))
        plt.draw()
        plt.pause(0.0001)

    # cv2.waitKey(0)

    def visualization_function(data):
        """Redraw the 2D camera windows and the 3D pose plots.

        The function reads the current projections and 3D joint estimates and
        pushes them into the already-created OpenCV windows and matplotlib
        artists stored in ``plot_handles``.

        Deliberately NOT decorated with numba's ``@jit``: the body only walks
        nested dicts and calls OpenCV / matplotlib, none of which numba can
        compile in nopython mode, so the decorator could only fail or warn.

        Args:
            data: dict providing ``'cameras'`` (per-camera frames with an
                ``'image_gui'`` image and per-joint ``'x_proj'``, ``'y_proj'``,
                ``'valid'`` and ``'color'`` entries), ``'poses3d'`` (per-frame,
                per-joint dicts with ``'X'``, ``'Y'``, ``'Z'``) and
                ``'window_manager'`` (object exposing ``waitForKey``).
        """
        cameras = data['cameras']
        poses3d = data['poses3d']
        wm = data['window_manager']

        def _move_line(line, start, end, keys):
            """Move an existing 3D line artist so that it joins start and end."""
            key_x, key_y, key_z = keys
            line.set_xdata([start[key_x], end[key_x]])
            line.set_ydata([start[key_y], end[key_y]])
            line.set_3d_properties(zs=[start[key_z], end[key_z]])

        # ---- 2D: draw the projected joints over a copy of each GUI image ----
        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                image = deepcopy(frame['image_gui'])  # keep the stored image clean
                for joint in frame['joints'].values():
                    # Invalid joints are drawn with a thicker cross.
                    thick = 3 if joint['valid'] else 10
                    drawCross2D(image, joint['x_proj'], joint['y_proj'], 15,
                                color=joint['color'], thickness=thick)
                # Show the image once per frame, after all crosses are drawn
                # (previously it was pushed to the window once per joint).
                cv2.imshow(camera_key + '_' + frame_key, image)

        # Pick the ground-truth skeleton once; only looked up when it is used.
        gt_links = None
        if args['has_ground_truth']:
            if human36m:
                gt_links = links_human36m
            elif mpi:
                gt_links = links_mpi_inf_3dhp
            else:
                gt_links = links_ground_truth

        # ---- 3D: update scatter points, links and axis limits per frame ----
        for frame_key, frame in selected_camera['frames'].items():
            handles = plot_handles[frame_key]
            axes = handles['axes']

            points = [poses3d[frame_key][joint_key] for joint_key in frame['joints']]
            X_vec = [point['X'] for point in points]
            Y_vec = [point['Y'] for point in points]
            Z_vec = [point['Z'] for point in points]

            # Same seeds as before: min starts at 1000000 and max starts at 0.
            min_z = min([1000000] + Z_vec)
            max_z = max([0] + Z_vec)

            avg_x = sum(X_vec) / len(X_vec)
            avg_y = sum(Y_vec) / len(Y_vec)

            # NOTE(review): these limits are proportional to the signed average,
            # so a negative average yields inverted limits and a zero average a
            # degenerate range - confirm whether abs() / a fixed span is intended.
            axes.set_xlim3d(avg_x - 10 * avg_x, avg_x + 10 * avg_x)
            axes.set_ylim3d(avg_y - 10 * avg_y, avg_y + 10 * avg_y)
            axes.set_zlim3d(1.1 * min_z, 1.1 * max_z)

            # Ground-truth joints share the scatter artist with the estimates;
            # they are appended after the limits were computed so they do not
            # influence the axis ranges.
            if args['has_ground_truth']:
                for gt_joint in ground_truth[frame_key].values():
                    gt_point = gt_joint['pose']
                    X_vec.append(float(gt_point['x']))
                    Y_vec.append(float(gt_point['y']))
                    Z_vec.append(float(gt_point['z']))

            handles['joint_handle']._offsets3d = (X_vec, Y_vec, Z_vec)
            draw3Dcoordinatesystem(axes, handles['coordinate_system'], 0, 0,
                                   min_z, 0.2, update=True)

            # Estimated skeleton links.
            for link_name, link in links.items():
                _move_line(handles[link_name]['link_handle'][0],
                           poses3d[frame_key][link['parent']],
                           poses3d[frame_key][link['child']],
                           ('X', 'Y', 'Z'))

            # Ground-truth skeleton links (lower-case coordinate keys).
            if gt_links is not None:
                for link_name, link in gt_links.items():
                    _move_line(handles['ground_truth'][link_name]['link_handle_gt'][0],
                               ground_truth[frame_key][link['parent']]['pose'],
                               ground_truth[frame_key][link['child']]['pose'],
                               ('x', 'y', 'z'))

            # One redraw per frame is enough; all artists are updated by now.
            plt.draw()

        # presumably lets the GUI process pending events - verify in WindowManager
        wm.waitForKey(time_to_wait=0.001, verbose=False, message='')

    if args['show_images']:
        opt.setVisualizationFunction(visualization_function, always_visualize=True, niterations=10)
        opt.internal_visualization = False

    # ----------------------------------------------
    # Start the optimization
    # ----------------------------------------------
    if args['phased_execution'] and args['show_images']:
        wm.waitForKey(verbose=True, message='Ready to start optimization. Press \'c\' to continue  skhfksdhfkshd.')

    # opt.callObjectiveFunction() # just to tests
    opt.startOptimization(optimization_options={'x_scale': 'jac', 'ftol': 1e-8,
                                                'xtol': 1e-8, 'gtol': 1e-8, 'diff_step': None})  # 'max_nfev': 1}

    # ----------------------------------------------
    # Generate final video
    # ----------------------------------------------
    if args['optimize_all'] and args['video']:
        fig = plt.figure(figsize=(50, 10))

        ax2 = fig.add_subplot(3, 2, 1)
        ax3 = fig.add_subplot(3, 2, 2)
        ax4 = fig.add_subplot(3, 2, 3)
        ax5 = fig.add_subplot(3, 2, 4)
        ax1 = fig.add_subplot(2, 1, 2, projection='3d')
        camera_axis = []
        camera_axis.append((ax2, ax3, ax4, ax5))
        ax1.set_title("3D pose")
        ax1.set_xlim3d(-1, 1)
        ax1.set_ylim3d(-3, -1)
        ax1.set_zlim3d(0, 2)
        ax1.set_xlabel('x')
        ax1.set_ylabel('y')
        ax1.set_zlabel('z')
        Xtotal = []
        Ytotal = []
        Ztotal = []
        Xtotal_gt = []
        Ytotal_gt = []
        Ztotal_gt = []
        colors_total = []
        xltotal = []
        yltotal = []
        zltotal = []
        xltotal_gt = []
        yltotal_gt = []
        zltotal_gt = []
        images_total = []

        for camera_key, camera in cameras.items():
            list_for_camera = []
            for frame_key, frame in camera['frames'].items():
                list_for_camera.append(cameras[camera_key]['frames'][frame_key]['image'])

            images_total.append(list_for_camera)

        for frame_key, frame in selected_camera['frames'].items():
            X_vec = []
            Y_vec = []
            Z_vec = []
            joint_colors = []
            xlines = []
            ylines = []
            zlines = []

            min_z = 1000
            for joint_key, joint in frame['joints'].items():
                point = poses3d[frame_key][joint_key]
                b, g, r = tuple(c / 255 for c in point['color'])
                joint_colors.append((r, g, b))
                X_vec.append(point['X'])
                Y_vec.append(point['Y'])
                Z_vec.append(point['Z'])

            if args['has_ground_truth']:
                X_vec_gt = []
                Y_vec_gt = []
                Z_vec_gt = []
                for joint_key, joint in ground_truth[frame_key].items():
                    point = ground_truth[frame_key][joint_key]['pose']
                    X_vec_gt.append(float(point['x']))
                    Y_vec_gt.append(float(point['y']))
                    Z_vec_gt.append(float(point['z']))

                Xtotal_gt.append(X_vec_gt)
                Ytotal_gt.append(Y_vec_gt)
                Ztotal_gt.append(Z_vec_gt)

            Xtotal.append(X_vec)
            Ytotal.append(Y_vec)
            Ztotal.append(Z_vec)
            colors_total.append(joint_colors)

            for link_name, link in links.items():
                joint0 = poses3d[frame_key][link['parent']]
                joint1 = poses3d[frame_key][link['child']]
                X0 = joint0['X']
                Y0 = joint0['Y']
                Z0 = joint0['Z']
                X1 = joint1['X']
                Y1 = joint1['Y']
                Z1 = joint1['Z']
                xlines.append([X0, X1])
                ylines.append([Y0, Y1])
                zlines.append([Z0, Z1])

            if args['has_ground_truth']:
                xlines_gt = []
                ylines_gt = []
                zlines_gt = []

                if human36m == False:
                    for link_name, link in links_ground_truth.items():
                        joint0 = ground_truth[frame_key][link['parent']]['pose']
                        joint1 = ground_truth[frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']
                        xlines_gt.append([X0, X1])
                        ylines_gt.append([Y0, Y1])
                        zlines_gt.append([Z0, Z1])
                else:
                    for link_name, link in links_human36m.items():
                        joint0 = ground_truth[frame_key][link['parent']]['pose']
                        joint1 = ground_truth[frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']
                        xlines_gt.append([X0, X1])
                        ylines_gt.append([Y0, Y1])
                        zlines_gt.append([Z0, Z1])

                xltotal_gt.append(xlines_gt)
                yltotal_gt.append(ylines_gt)
                zltotal_gt.append(zlines_gt)

            xltotal.append(xlines)
            yltotal.append(ylines)
            zltotal.append(zlines)

        def update_plot(frames):
            """Render one frame of the final video.

            Clears all five axes, then redraws the 3D skeleton (and the ground
            truth when ``args['has_ground_truth']`` is set) on ``ax1`` and the
            camera images on the remaining axes.

            Args:
                frames: index of the frame to draw, as supplied by
                    ``FuncAnimation``; used to index the per-frame lists
                    collected before the animation was created.
            """
            i = frames

            # Every axes must be cleared, otherwise artists pile up frame after
            # frame (ax5 was previously skipped because ax4 was cleared twice).
            for axis in (ax1, ax2, ax3, ax4, ax5):
                axis.clear()

            ax1.set_title("3D pose")
            ax1.set_xlim3d(-1, 1)
            ax1.set_ylim3d(-1, 1)
            ax1.set_zlim3d(0, 2)
            ax1.set_xlabel('x', fontsize=20)
            ax1.set_ylabel('y', fontsize=20)
            ax1.set_zlabel('z', fontsize=20)
            plt.setp(ax1.get_xticklabels(), visible=False)
            plt.setp(ax1.get_yticklabels(), visible=False)
            plt.setp(ax1.get_zticklabels(), visible=False)

            # Anchor the z range to the lowest estimated joint of this frame.
            z_min = min(Ztotal[i])
            ax1.set_zlim3d(z_min, z_min + 2)
            # Draw the reference frame at the bottom of the visible z range
            # (the enclosing scope's min_z is a constant and not per-frame).
            draw3Dcoordinatesystem(ax1, [], xc=0, yc=0, zc=z_min, size=0.5)

            ax1.scatter(Xtotal[i], Ytotal[i], Ztotal[i], c=colors_total[i])

            # Draw every link, including the last one.
            for xs, ys, zs in zip(xltotal[i], yltotal[i], zltotal[i]):
                ax1.plot(xs, ys, zs, c=(0.5, 0.5, 0.5))

            if args['has_ground_truth']:
                ax1.scatter(Xtotal_gt[i], Ytotal_gt[i], Ztotal_gt[i], c='k')

                for xs, ys, zs in zip(xltotal_gt[i], yltotal_gt[i], zltotal_gt[i]):
                    ax1.plot(xs, ys, zs, c=(0, 0, 0))

            # One image axes per camera; images are stored as BGR (OpenCV).
            for n, camera_images in enumerate(images_total):
                image_axis = camera_axis[0][n]
                image_axis.imshow(cv2.cvtColor(camera_images[i], cv2.COLOR_BGR2RGB))
                image_axis.set_title("camera_" + str(n + 1))
                image_axis.set_aspect('equal', adjustable='box')
                # image_axis.axis('square')
                image_axis.axis('off')

            # fig.tight_layout()

        ani = animation.FuncAnimation(fig, update_plot, frames=frames_total, interval=200, repeat=False)

        if args['save_output'] and args['video']:
            ani.save(dataset_folder + '3d.mp4')

        plt.draw()

    # if args['debug']:
    #     plt.figure()
    #     iterations = []
    #     avgs = []
    #     avgs_joint = {}
    #     joint_colors = []
    #     append_colors = True
    #     for iteration_idx in errors_per_iteration:
    #         avg_per_iteration = 0
    #         iteration = errors_per_iteration[iteration_idx]
    #         n_frames = 0
    #         for frame_idx in iteration:
    #             frame = iteration[frame_idx]
    #             avg_per_frame = 0
    #             n_joints = 0
    #             for joint_idx in joints_to_use:
    #                 avgs_joint[joint_idx] = []
    #                 if append_colors == True:
    #                     color = selected_camera['frames'][frame_idx]['joints'][joint_idx]['color']
    #                     b, g, r = tuple(c / 255 for c in color)
    #                     joint_colors.append((r, g, b))
    #
    #                 rmse = frame[joint_idx]['rmse']
    #                 avg_per_frame += rmse
    #                 n_joints += 1
    #             append_colors = False
    #             avg_per_frame = avg_per_frame / n_joints
    #             avg_per_iteration += avg_per_frame
    #             n_frames += 1
    #         avg_per_iteration = avg_per_iteration / n_frames
    #         iterations.append(iteration_idx)
    #         avgs.append(avg_per_iteration)
    #
    #     for iteration_idx in errors_per_iteration:
    #         iteration = errors_per_iteration[iteration_idx]
    #         for joint_idx in joints_to_use:
    #             avg_per_joint = 0
    #             n_frames = 0
    #             for frame_idx in iteration:
    #                 frame = iteration[frame_idx]
    #                 rmse = frame[joint_idx]['rmse']
    #                 avg_per_joint += rmse
    #                 n_frames += 1
    #             avg_per_joint = avg_per_joint / n_frames
    #             avgs_joint[joint_idx].append(avg_per_joint)
    #             # print(
    #             #     "iteration #" + str(iteration_idx) + ' joint ' + str(joint_idx) + ': rmse = ' + str(avg_per_joint))
    #
    #
    #
    #     markers = ['.', 'o', 'v', '^', '<', '>', '1', '2', '3', '4', '8', 's', 'p', 'P', '*', 'h', 'H', '+', 'x',
    #                'X', 'D', 'd', '.', ',', 'o', 'v', '^', '<', '>', '1', '2', '3', '4', '8', 's', 'p', 'P', '*', 'h',
    #                'H', '+', 'x', 'X', 'D', 'd']
    #     i = 0
    #     for joint_idx, joint in avgs_joint.items():
    #         plt.plot(iterations, avgs_joint[joint_idx],
    #                  c=joint_colors[i], label=joint_idx, marker=markers[i], markersize=7)
    #         i += 1
    #
    #     plt.plot(iterations, avgs,
    #              'k-', label='Average per iteration', marker='o', linewidth=3, markersize=8)
    #
    #     plt.xlabel('Iterations')
    #     plt.ylabel('RMSE')
    #     plt.legend()
    #     plt.grid()
    #
    #     if not args['show_images']:
    #         plt.show()
    #     else:
    #         plt.draw()

    if args['save_output']:
        print("Saving output files...")
        # appendix = args['2d_poses'][8:-4]
        #
        # for camera_key, camera in cameras.items():
        #     cameras[camera_key]['intrinsics'] = cameras[camera_key]['intrinsics'].tolist()
        #     cameras[camera_key]['extrinsics'] = cameras[camera_key]['extrinsics'].tolist()
        #     cameras[camera_key]['distortion'] = cameras[camera_key]['distortion'].tolist()
        #     for frame_key, frame in camera['frames'].items():
        #         cameras[camera_key]['frames'][frame_key]['image'] = cameras[camera_key]['frames'][frame_key][
        #             'image'].tolist()
        #         if 'image_gui' in cameras[camera_key]['frames'][frame_key]:
        #             del cameras[camera_key]['frames'][frame_key]['image_gui']
        #         del cameras[camera_key]['frames'][frame_key]['image']
        if detector == 'mediapipe':
            with open(dataset_folder + "poses3d_mp.json", "w") as fp:
                json.dump(poses3d, fp)
            with open(dataset_folder + "cameras_mp.json", "w") as fp:
                json.dump(cameras, fp)
        elif detector == 'openpose':
            with open(dataset_folder + "poses3d_op.json", "w") as fp:
                json.dump(poses3d, fp)
            with open(dataset_folder + "cameras_op.json", "w") as fp:
                json.dump(cameras, fp)
        else:
            with open(dataset_folder + "poses3d_gt.json", "w") as fp:
                json.dump(poses3d, fp)
            # with open(dataset_folder + "cameras_gt.json", "w") as fp:
            #     json.dump(cameras, fp)

        # appendix = appendix + '_' + str(args['start_frame']) + '_' + str(
        #     int(args['start_frame']) + int(args['max_frames']))
        #
        # if detector == 'mediapipe':
        #     with open(dataset_folder + "poses3d_mp" + appendix + ".json", "w") as fp:
        #         json.dump(poses3d, fp)
        #     with open(dataset_folder + "cameras_mp" + appendix + ".json", "w") as fp:
        #         json.dump(cameras, fp)
        # elif detector == 'openpose':
        #     with open(dataset_folder + "poses3d_op" + appendix + ".json", "w") as fp:
        #         json.dump(poses3d, fp)
        #     with open(dataset_folder + "cameras_op" + appendix + ".json", "w") as fp:
        #         json.dump(cameras, fp)
        # else:
        #     with open(dataset_folder + "poses3d_gt" + appendix + ".json", "w") as fp:
        #         json.dump(poses3d, fp)
        #     with open(dataset_folder + "cameras_gt" + appendix + ".json", "w") as fp:
        #         json.dump(cameras, fp)
        # print("Files saved!")

    if args['phased_execution'] and args['show_images']:
        wm.waitForKey(verbose=True, message='Optimization finished. Press \'q\' to quit.')


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
