#!/usr/bin/env python3
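"""Frame-by-frame optimization of 3D human poses from multi-view 2D detections.

For every frame, a sliding window of recent frames (see generate_frame_list) is
optimized with projection, frame-to-frame and link-length residuals, using the
objective function defined in objective_function.py.
"""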

import argparse
import math
import os
import json
import sys
from functools import partial
from copy import deepcopy
from pathlib import Path
import h5py
import imageio

import matplotlib.animation as animation
import matplotlib.ticker as plticker
import cv2
import numpy as np
import matplotlib.pyplot as plt
import yaml
from yaml import SafeLoader
from natsort import natsorted

if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager

from mpl_toolkits.mplot3d import Axes3D

sys.path.append(str(Path("../").resolve()))  # add previous folder to python path

from utils.transforms import getTransform, get_transform_tree_dict
from utils.draw import drawSquare2D, drawCross2D, drawCircle, drawDiagonalCross2D, draw3Dcoordinatesystem
from utils.links import links_ground_truth, joint_idxs, intial_estimate_joints, joints_human36m, \
    joints_mediapipe, links_human36m, links_mediapipe, links_openpose, joint_idxs_op_mpi, links_op_mpi, \
    links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from utils.projection import projectToCamera
from utils.load_cam_params import getCamParamsHuman36m, getCamParamsMPI
from utils.utils import generate_frame_list
from objective_function import objectiveFunction


def main():
    # --------------------------------------------------------
    # Arguments
    # --------------------------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-sff", "--skip_frame_to_frame_residuals", help="Use frame to frame residuals", action="store_true",
                    default=False)
    ap.add_argument("-sll", "--skip_link_length_residuals", help="Use link length residuals", action="store_true",
                    default=False)
    ap.add_argument("-gt", "--has_ground_truth", help="Use groundtruth in visualization", action="store_true",
                    default=False)
    ap.add_argument("-db", "--debug", help="Prints debug lines", action="store_true",
                    default=False)
    ap.add_argument("-pt", "--print_time", help="Prints times", action="store_true",
                    default=False)
    ap.add_argument("-si", "--show_images", help="Show optimization images", action="store_true",
                    default=False)
    ap.add_argument("-s3d", "--show_3d", help="Show optimization 3D visualization", action="store_true",
                    default=False)
    ap.add_argument("-pe", "--phased_execution", help="After optimization finished keep windows running",
                    action="store_true",
                    default=False)
    ap.add_argument("-o", "--save_output", help="Save output json file with final skeleton and poses",
                    action="store_true",
                    default=False)
    ap.add_argument("-v", "--video", help="Saves .mp4 video in the dataset folder",
                    action="store_true",
                    default=False)
    # ap.add_argument("-all", "--optimize_all",
    #                 help=" Use to optimize the entire dataset/video. If this argument is set, the program ignores the -frames flag.",
    #                 action="store_true",
    #                 default=False)
    ap.add_argument("-max", "--max_frames", help="Maximum number of frames when calibrating the entire video.",
                    type=int)
    ap.add_argument("-n_ff", "--number_of_frame_to_frame", help="Number of previous frames to use in optimization.",
                    type=int, default=20)
    ap.add_argument("-sf", "--start_frame", help="Frame to start optimization.", type=int, default=0)
    ap.add_argument("-2d_detector", "--2d_detector",
                    help="2D detector used for 2D human pose estimation. Current options: openpose (default), mediapipe, groundtruth",
                    type=str, default='openpose')
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M. Please use -dataset mpi to optimize Mpi Inf 3DHP",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Human 3.6M section to optimize. Only works when the dataset is set to human36m",
                    type=str, default='S1')
    ap.add_argument("-a", "--action",
                    help="Human 3.6M action to optimize. Only works when the dataset is set to human36m",
                    type=str, default='Directions-1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence optimize. Only works when the dataset is set to human36m",
                    type=str, default='Seq1')
    ap.add_argument("-xacro", "--xacro_name", help="Name of xacro file. Must be inside folder xacros/.", type=str,
                    default="novo.urdf.xacro")
    ap.add_argument("-2d_poses", "--2d_poses",
                    help="Specify the name of the .npy file containing the 2d poses. Must be inside dataset_name/cache/camera_x/filename.npy",
                    type=str)
    # ap.add_argument("-frames", "--frames_to_use", help="Frames to use in optimization. ", nargs='+', type=int,
    #                 default=[1, 2, 3])
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to use in optimization", nargs='+',
                    default=['camera_0', 'camera_4', 'camera_5', 'camera_8'])
    # default = ['54138969', '55011271', '58860488', '60457274']

    args = vars(ap.parse_args())
    # --------------------------------------------------------
    # Configuration and verifications
    # --------------------------------------------------------
    world_frame = 'world'

    dataset_folder = '../../images/' + args['dataset_name'] + '/'

    human36m = False
    mpi = False
    if args['dataset_name'] == "human36m":
        human36m = True
        dataset_folder = '../../images/human36m/processed/' + args['section'] + '/' + args['action'] + '/'
        # print(dataset_folder)
    elif args['dataset_name'] == "mpi":
        mpi = True
        dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'

    detector = 'groundtruth'
    if args['2d_detector'] == "mediapipe":
        detector = 'mediapipe'
        links = links_mediapipe
        skeleton_joints = joints_mediapipe
    elif args['2d_detector'] == "openpose":
        detector = 'openpose'
        links = links_openpose
        skeleton_joints = joint_idxs
        if mpi:
            skeleton_joints = joint_idxs_op_mpi
            links = links_op_mpi
    else:
        links = links_mpi_inf_3dhp
        skeleton_joints = joints_mpi_inf_3dhp

    print("Running with detector " + detector + " and dataset " + args['dataset_name'])

    # check if dataset folder exists:
    if not os.path.exists(dataset_folder):
        raise Exception("The dataset folder does not exist!")

    if not human36m and not mpi:
        xacro_name = args['xacro_name']
        file_xacro = "../xacros/" + xacro_name

        if not os.path.exists(file_xacro):
            raise Exception(
                "The xacro file does not exist! The xacro file must be inside the hpe/scripts/xacros folder.")
        transform_dict = get_transform_tree_dict(file_xacro)
    else:
        if human36m:
            cam_params = getCamParamsHuman36m([args['section']], args['cameras_to_use'], [args['action']])
        if mpi:
            cam_params = getCamParamsMPI(args['section'], args['sequence'], args['cameras_to_use'])

    # Define cameras to use
    cameras = {}
    for camera in args["cameras_to_use"]:
        # check if sensor exists:
        if not human36m and not mpi:
            if os.path.exists(dataset_folder + camera):
                cameras[camera] = {}
            else:
                raise Exception("The given sensor " + camera + " does not exist in the dataset!")
        else:
            if os.path.exists(dataset_folder + 'imageSequence/' + camera):
                cameras[camera] = {}
            else:
                raise Exception("The given sensor " + camera + " does not exist in the dataset!")


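    # Image file names to optimize (taken from the first camera and verified to exist for every camera)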
    idxs_to_use = []

    joints_to_use = []
    for _, link in links.items():  # derive all joints to use from each link's parent and child
        joints_to_use.append(link['parent'])
        joints_to_use.append(link['child'])

    joints_to_use = list(set(joints_to_use))  # remove repetitions

    # Paths to files
    first_camera = True
    frames_total = 10000000000
    for camera_key, camera in cameras.items():
        if not human36m and not mpi:
            camera['folder'] = dataset_folder + camera_key + '/'
            camera['folder_skeleton'] = dataset_folder + 'cache/' + camera_key + '/'
        else:
            camera['folder'] = dataset_folder + 'imageSequence/' + camera_key + '/'
            camera['folder_skeleton'] = dataset_folder + 'cache/' + camera_key + '/'
        camera['files'] = natsorted(os.listdir(camera['folder']))
        if args['2d_poses'] and os.path.exists(camera['folder_skeleton'] + args['2d_poses']):
            camera['file_skeleton'] = camera['folder_skeleton'] + args['2d_poses']
        elif detector == 'openpose':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_op.npy'
        elif detector == 'mediapipe':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_mediapipe.npy'
        elif detector == 'groundtruth':
            camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose_gt.npy'
        else:
            raise Exception("2D detector not available!")

        n_frames_camera = len(camera['files'])
        if n_frames_camera < frames_total:
            frames_total = n_frames_camera

        if not os.path.exists(camera['file_skeleton']):
            raise Exception(
                "The 2D poses file does not exist! The file must be inside the images/dataset_name/cache/camera_x/ folder.")

        camera['frame_id'] = camera_key + '_rgb_optical_frame'

        start_frame = args['start_frame']

        if first_camera:
            idxs_to_use = camera['files']
            if args['max_frames']:
                max_frames = int(args['max_frames'])  # avoid shadowing the built-in max()
                idxs_to_use = camera['files'][start_frame:start_frame + max_frames]

        # Check if frame exists
        for idx in idxs_to_use:
            # print(camera['folder'] + str(idx) + '.png')
            if not human36m and not mpi:
                if not os.path.exists(camera['folder'] + str(idx) + '.png'):
                    raise Exception("The frame " + str(idx) + " does not exist in the dataset!")
            else:
                if not os.path.exists(camera['folder'] + str(idx)):
                    raise Exception("The frame " + str(idx) + " does not exist in the dataset!")

        first_camera = False
        # TODO change to frame_id

    # Read camera intrinsics
    for camera_key, camera in cameras.items():
        if not human36m and not mpi:
            yaml_folder = dataset_folder + camera_key + '_sim.yaml'
            with open(yaml_folder) as f:
                data = yaml.load(f, Loader=SafeLoader)
            camera['intrinsics'] = np.reshape(data['camera_matrix']['data'], (3, 3))
            camera['distortion'] = np.reshape(data['distortion_coefficients']['data'], (5, 1))
            camera['width'] = data['image_width']
            camera['height'] = data['image_height']
        else:
            camera['intrinsics'] = cam_params[camera_key]['K']
            camera['distortion'] = cam_params[camera_key]['dist']
            if human36m:
                if camera_key == '54138969' or camera_key == '60457274':
                    camera['width'] = 1000
                    camera['height'] = 1002
                else:
                    camera['width'] = 1000
                    camera['height'] = 1000
            if mpi:
                camera['width'] = cam_params[camera_key]['width']
                camera['height'] = cam_params[camera_key]['height']

    # Read camera extrinsics
    for camera_key, camera in cameras.items():
        if not human36m and not mpi:
            camera['extrinsics'] = getTransform(camera['frame_id'], world_frame, transform_dict)
        else:
            camera['extrinsics'] = cam_params[camera_key]['T']

    # Read images and files
    for _, camera in cameras.items():
        poses_np = np.load(camera['file_skeleton'], allow_pickle=True)
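        # The .npy file is expected to hold one 2D pose per frame: an array of joints with
        # (x, y) or (x, y, confidence) per joint, indexed by the 'idx' field of skeleton_joints.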
        camera['frames'] = {}
        for idx, (file, pose_np) in enumerate(
                zip(idxs_to_use, poses_np.tolist()[start_frame:start_frame + len(idxs_to_use)])):
            idx = idx + start_frame
            frame = cv2.imread(camera['folder'] + file)

            joints = {}
            for joint_idx_key, joint_data_value in skeleton_joints.items():
                joint_idx_value = joint_data_value['idx']
                if joint_idx_key not in joints_to_use:
                    continue

                point = pose_np[joint_idx_value]
                x, y = point[0], point[1]

                valid = (x != 0) and (y != 0)  # (0, 0) is treated as a missing detection
                color = joint_data_value['color'] if valid else (0, 0, 255)
                joints[joint_idx_key] = {'x': x, 'y': y, 'valid': valid,
                                         'x_proj': 0.0, 'y_proj': 0.0, 'color': color}
                if len(point) == 3:  # a third column, when present, carries the detection confidence
                    joints[joint_idx_key]['confidence'] = point[2]
                # TODO implement a better initialization

            camera['frames'][str(idx)] = {'image_path': camera['folder'] + file, 'image': frame, 'joints': joints}

    selected_camera_key = list(cameras.keys())[0]  # select the first in the list arbitrarily
    selected_camera = cameras[selected_camera_key]

    # read ground truth (left as None when the -gt flag is not given, but still registered as a data model)
    ground_truth = None
    if args['has_ground_truth']:
        ground_truth_path = dataset_folder + "ground_truth.txt"
        if os.path.exists(ground_truth_path):
            with open(ground_truth_path, "r") as fp:
                ground_truth = json.load(fp)
        else:
            raise Exception("The ground truth file does not exist! Run without -gt flag.")

    # create the poses 3D dictionary.
    # NOTE must have the same size as the frames in the sensors

    # for frame_key in idxs_to_use:
    # Since all the frames lists in the cameras are of the same size, we can use the first in the list.
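    # poses3d[frame_key][joint_key] holds the evolving 3D estimate of each joint;
    # every new frame is initialized from the previous frame's solution (zeros for the start frame).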
    poses3d = {}
    previous_frame_key = ''
    this_frame_key = ''
    frame_list = []
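    # Main loop: one optimization problem is built and solved per frame, over a sliding window of recent frames.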
    for frame_key, frame in selected_camera['frames'].items():
        this_frame_key = frame_key
        if int(frame_key) > frames_total:
            continue
        poses3d[frame_key] = {}
        for joint_key, joint in frame['joints'].items():
            if int(frame_key) == start_frame:
                poses3d[frame_key][joint_key] = {'X': 0.0, 'Y': 0.0, 'Z': 0.0, 'color': joint['color']}
            else:
                poses3d[frame_key][joint_key] = poses3d[previous_frame_key][joint_key]
        previous_frame_key = frame_key

        frame_list = generate_frame_list(start_frame, frame_key, args['number_of_frame_to_frame'])
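        # frame_list presumably holds the string keys of the current frame plus up to
        # args['number_of_frame_to_frame'] preceding frames, clamped at start_frame
        # (e.g. something like ['0', '1', '2'] while optimizing frame 2).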

        # -----------------------------------------
        # OPTIMIZATION
        # -----------------------------------------
        # Configure optimizer
        opt = Optimizer()
        opt.addDataModel('args', args)
        opt.addDataModel('cameras', cameras)
        opt.addDataModel('poses3d', poses3d)
        opt.addDataModel('frame_list', frame_list)
        opt.addDataModel('groundtruth', ground_truth)

        errors_per_iteration = {}
        opt.addDataModel('errors', errors_per_iteration)

        # opt.getNumberOfFunctionCallsPerIteration(optimization_options):

        # opt.addDataModel('iteration', iteration)

        # ----------------------------------------------
        # Setters and getters
        # ----------------------------------------------
        def getJointXYZ(data, frame_key, joint_key):
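            """Getter used by the optimizer: return [X, Y, Z] of one joint in one frame of poses3d."""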
            d = data[frame_key][joint_key]
            X, Y, Z = d['X'], d['Y'], d['Z']
            return [X, Y, Z]

        def setJointXYZ(data, values, frame_key, joint_key):
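            """Setter used by the optimizer: write the optimized [X, Y, Z] back into poses3d."""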
            X, Y, Z = values[0], values[1], values[2]
            d = data[frame_key][joint_key]
            d['X'] = X
            d['Y'] = Y
            d['Z'] = Z

        # Test getters and setters
        # setJointXYZ(cameras, [33,44,55], 'camera_2', '200', 'RWrist')
        # pt = getJointXYZ(cameras, 'camera_2', '200', 'RShoulder')
        # print(pt)
        # pt = getJointXYZ(cameras, 'camera_3', '200', 'RShoulder')
        # print(pt)

        # ----------------------------------------------
        # Create the parameters
        # ----------------------------------------------
        for frame_key, frame in selected_camera['frames'].items():
            if frame_key not in frame_list:
                continue
            for joint_key, joint in frame['joints'].items():
                group_name = 'frame_' + frame_key + '_joint_' + joint_key
                opt.pushParamVector(group_name, data_key='poses3d',
                                    getter=partial(getJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                    setter=partial(setJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                    suffix=['_X', '_Y', '_Z'])
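        # Each frame/joint pair in the window contributes three parameters (X, Y, Z); the parameter
        # names are presumably formed from the group name plus the suffix, e.g. 'frame_5_joint_RWrist_X'.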

        # opt.printParameters()
        # ----------------------------------------------
        # Define the objective function
        # ----------------------------------------------
        opt.setObjectiveFunction(objectiveFunction)
        # ----------------------------------------------
        # Define the residuals
        # ----------------------------------------------

        # Projection residuals
        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                if frame_key not in frame_list:
                    continue
                for joint_key, joint in frame['joints'].items():
                    if not joint['valid']:  # skip invalid joints
                        continue

                    parameter_pattern = 'frame_' + frame_key + '_joint_' + joint_key
                    residual_key = 'projection_sensor_' + camera_key + '_frame_' + frame_key + '_joint_' + joint_key

                    params = opt.getParamsContainingPattern(
                        pattern=parameter_pattern)  # the X, Y, Z parameters of this joint in this frame
                    opt.pushResidual(name=residual_key, params=params)

        # Frame distance residuals
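        # Each residual ties a joint's 3D position in two consecutive frames, discouraging large jumps between frames.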
        if not args['skip_frame_to_frame_residuals']:
            # print(frame_list)
            for initial_frame_key, final_frame_key in zip(list(poses3d.keys())[:-1], list(poses3d.keys())[1:]):
                if not int(final_frame_key) - int(initial_frame_key) == 1:  # frames are not consecutive
                    continue
                if final_frame_key not in frame_list:
                    continue

                initial_pose = poses3d[initial_frame_key]

                for joint_key in initial_pose.keys():
                    residual_key = 'consecutive_frame_' + initial_frame_key + '_frame_' + final_frame_key + '_joint_' + joint_key
                    params = opt.getParamsContainingPattern(
                        pattern='frame_' + initial_frame_key + '_joint_' + joint_key + '_')  # get parameters for frame initial_frame
                    params.extend(opt.getParamsContainingPattern(
                        pattern='frame_' + final_frame_key + '_joint_' + joint_key + '_'))  # get parameters for frame final_frame
                    opt.pushResidual(name=residual_key, params=params)

        # Link length residuals
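        # Each residual couples the parent and child joints of a link, keeping the bone length consistent over time.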
        if not args['skip_link_length_residuals']:
            for link_key, link in links.items():
                params = opt.getParamsContainingPattern(pattern='joint_' + link['parent'] + '_')
                params.extend(
                    opt.getParamsContainingPattern(pattern='joint_' + link['child'] + '_'))

                residual_key = 'link_length_' + 'joint_' + link['parent'] + '_joint_' + link['child']
                opt.pushResidual(name=residual_key, params=params)

        if args['debug']:
            opt.printResiduals()

        # ----------------------------------------------
        # Compute sparse matrix
        # ----------------------------------------------
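        # The sparsity pattern records which residuals depend on which parameters, so the solver can exploit a sparse Jacobian.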
        opt.computeSparseMatrix()
        if args['debug']:
            opt.printSparseMatrix()

        # The optimization itself is started further below, once the optional visualization has been set up.

        # ----------------------------------------------
        # Visualization setup
        # ----------------------------------------------
        if args['show_images']:
            for camera_key, camera in cameras.items():
                cam_frame = camera['frames'][this_frame_key]
                cam_frame['image_gui'] = deepcopy(cam_frame['image'])

                # Draw links in the image
                for link_name, link in links.items():
                    joint0 = cam_frame['joints'][link['parent']]
                    joint1 = cam_frame['joints'][link['child']]

                    if not joint0['valid'] or not joint1['valid']:
                        continue

                    x0, y0 = int(joint0['x']), int(joint0['y'])
                    x1, y1 = int(joint1['x']), int(joint1['y'])

                    cv2.line(cam_frame['image_gui'], (x0, y0), (x1, y1), (128, 128, 128), 3)

                # Draw 2D detections on the image
                for joint_idx, joint in cam_frame['joints'].items():
                    x, y = int(joint['x']), int(joint['y'])
                    color = joint['color']
                    if not joint['valid']:
                        print(joint_idx + " is not valid in camera " + camera_key)
                        color = (255, 0, 0)

                    # scale the marker with the detection confidence (assume full confidence when it is missing)
                    square_size = 5 + (20 - 5) * joint.get('confidence', 1.0)
                    cv2.putText(cam_frame['image_gui'], joint_idx, (x, y), cv2.FONT_HERSHEY_PLAIN, 0.5, (0, 255, 0))
                    drawSquare2D(cam_frame['image_gui'], x, y, square_size, color=color, thickness=3)

            # Draw projected groundtruth in images
            if args['has_ground_truth']:
                for camera_key, camera in cameras.items():
                    cam_frame = camera['frames'][this_frame_key]
                    for joint_key, joint in ground_truth[this_frame_key].items():

                        pts_in_world = np.ndarray((4, 1), dtype=float)
                        pts_in_world[0][0] = joint['pose']['x']
                        pts_in_world[1][0] = joint['pose']['y']
                        pts_in_world[2][0] = joint['pose']['z']
                        pts_in_world[3][0] = 1

                        # Transform from world to camera coordinates, then project to the image
                        pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)

                        if mpi:  # MPI intrinsics are used without a distortion model
                            pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                                 None,
                                                                 camera['width'],
                                                                 camera['height'], pts_in_sensor[0:3, :])
                        else:
                            pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                                 np.array(camera['distortion']),
                                                                 camera['width'],
                                                                 camera['height'], pts_in_sensor[0:3, :])

                        xpix_projected = pts_in_image[0][0]
                        ypix_projected = pts_in_image[1][0]

                        x, y = int(xpix_projected), int(ypix_projected)
                        color = (0, 0, 0)

                        drawDiagonalCross2D(cam_frame['image_gui'], x, y, 10, color=color, thickness=3)

            # Create visualization in images
            for camera_key, camera in cameras.items():
                window_name = camera_key + '_' + this_frame_key
                # cv2.namedWindow(window_name, cv2.WINDOW_FULLSCREEN)
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                cv2.imshow(window_name, camera['frames'][this_frame_key]['image_gui'])
        if args['show_3d']:
            initial_frame_key = str(start_frame)

            # Setup 3D visualization
            if this_frame_key == initial_frame_key:
                plot_handles = {}
                plot_handles[initial_frame_key] = {}
                plot_handles[initial_frame_key]['figure'] = plt.figure()
                plot_handles[initial_frame_key]['figure'].suptitle('Frame #' + str(initial_frame_key), fontsize=14)
                plot_handles[initial_frame_key]['axes'] = plot_handles[initial_frame_key]['figure'].add_subplot(111, projection='3d')
                plot_handles[initial_frame_key]['axes'].set_xlim3d(-1000, 1000)
                plot_handles[initial_frame_key]['axes'].set_ylim3d(-1000, 1000)
                plot_handles[initial_frame_key]['axes'].set_zlim3d(0, 2000)
                plot_handles[initial_frame_key]['axes'].set_xlabel('x', fontsize=20)
                plot_handles[initial_frame_key]['axes'].set_ylabel('y', fontsize=20)
                plot_handles[initial_frame_key]['axes'].set_zlabel('z', fontsize=20)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_xticklabels(), visible=False)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_yticklabels(), visible=False)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_zticklabels(), visible=False)

                figures = []
                for plot_handle_key, plot_handle in plot_handles.items():
                    figures.append(plot_handle['figure'])

                wm = WindowManager(figs=figures)
            else:
                plot_handles[initial_frame_key]['axes'].clear()
                plot_handles[initial_frame_key]['figure'].suptitle('Frame #' + str(initial_frame_key), fontsize=14)
                plot_handles[initial_frame_key]['axes'] = plot_handles[initial_frame_key]['figure'].add_subplot(111, projection='3d')
                plot_handles[initial_frame_key]['axes'].set_xlim3d(-1000, 1000)
                plot_handles[initial_frame_key]['axes'].set_ylim3d(-1000, 1000)
                plot_handles[initial_frame_key]['axes'].set_zlim3d(0, 2000)
                plot_handles[initial_frame_key]['axes'].set_xlabel('x', fontsize=20)
                plot_handles[initial_frame_key]['axes'].set_ylabel('y', fontsize=20)
                plot_handles[initial_frame_key]['axes'].set_zlabel('z', fontsize=20)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_xticklabels(), visible=False)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_yticklabels(), visible=False)
                plt.setp(plot_handles[initial_frame_key]['axes'].get_zticklabels(), visible=False)

            # Register the persistent window manager with this frame's optimizer (a new Optimizer is created per frame)
            opt.addDataModel('window_manager', wm)

            # Gather the 3D points to plot (current estimates and, optionally, ground truth)
            X_vec = []
            Y_vec = []
            Z_vec = []
            joint_colors = []

            for joint_key, point in poses3d[this_frame_key].items():
                X_vec.append(point['X'])
                Y_vec.append(point['Y'])
                Z_vec.append(point['Z'])
                b, g, r = tuple(c / 255 for c in point['color'])
                joint_colors.append((r, g, b))

            # ground truth joints are drawn in black
            if args['has_ground_truth']:
                for joint_key, joint in ground_truth[this_frame_key].items():
                    point = ground_truth[this_frame_key][joint_key]['pose']
                    X_vec.append(float(point['x']))
                    Y_vec.append(float(point['y']))
                    Z_vec.append(float(point['z']))
                    b, g, r = (0, 0, 0)
                    joint_colors.append((r, g, b))

            plot_handles[initial_frame_key]['joint_handle'] = plot_handles[initial_frame_key]['axes'].scatter(X_vec, Y_vec, Z_vec,
                                                                                              c=joint_colors)  # Draw N points
            plot_handles[initial_frame_key]['coordinate_system'] = draw3Dcoordinatesystem(plot_handles[initial_frame_key]['axes'], [],
                                                                                  xc=0, yc=0, zc=0, size=0.2)

            # Draw ground truth links in 3D
            if args['has_ground_truth']:
                plot_handles[initial_frame_key]['ground_truth'] = {}
                for link_name, link in links_mpi_inf_3dhp.items():
                    plot_handles[initial_frame_key]['ground_truth'][link_name] = {}
                    joint0 = ground_truth[this_frame_key][link['parent']]['pose']
                    joint1 = ground_truth[this_frame_key][link['child']]['pose']
                    X0 = joint0['x']
                    Y0 = joint0['y']
                    Z0 = joint0['z']
                    X1 = joint1['x']
                    Y1 = joint1['y']
                    Z1 = joint1['z']

                    plot_handles[initial_frame_key]['ground_truth'][link_name]['link_handle_gt'] = plot_handles[initial_frame_key][
                        'axes'].plot([X0, X1], [Y0, Y1],
                                     [Z0, Z1],
                                     c=(0, 0, 0))

            # Plot the estimated links in 3D
            for link_name, link in links.items():
                plot_handles[initial_frame_key][link_name] = {}
                joint0 = poses3d[this_frame_key][link['parent']]
                joint1 = poses3d[this_frame_key][link['child']]
                X0 = joint0['X']
                Y0 = joint0['Y']
                Z0 = joint0['Z']
                X1 = joint1['X']
                Y1 = joint1['Y']
                Z1 = joint1['Z']

                plot_handles[initial_frame_key][link_name]['link_handle'] = plot_handles[initial_frame_key]['axes'].plot([X0, X1],
                                                                                                         [Y0, Y1],
                                                                                                         [Z0, Z1],
                                                                                                         c=(
                                                                                                             0.5, 0.5,
                                                                                                             0.5))
            plt.draw()
            plt.pause(0.0001)

        # cv2.waitKey(0)
        #
        def visualization_function(data):
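            """Refresh the image overlays and/or the 3D plot with the current estimates (called periodically by the optimizer)."""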
            cameras = data['cameras']
            poses3d = data['poses3d']
            # the window manager data model only exists when the 3D view is enabled
            wm = data['window_manager'] if 'window_manager' in data else None
            args = data['args']
            # print('Visualization function called')

            if args['show_images']:
                for camera_key, camera in cameras.items():
                    cam_frame = camera['frames'][this_frame_key]
                    image = deepcopy(cam_frame['image_gui'])
                    for joint_key, joint in cam_frame['joints'].items():
                        x = joint['x_proj']
                        y = joint['y_proj']
                        # invalid joints are drawn with a thicker cross
                        thick = 10 if not joint['valid'] else 3
                        drawCross2D(image, x, y, 15, color=joint['color'], thickness=thick)
                    window_name = camera_key + '_' + this_frame_key
                    cv2.imshow(window_name, image)

            if args['show_3d']:
                X_vec = []
                Y_vec = []
                Z_vec = []

                min_z = 1000000
                max_z = 0

                for joint_key, point in poses3d[this_frame_key].items():
                    # print(point)
                    X_vec.append(point['X'])
                    Y_vec.append(point['Y'])
                    Z_vec.append(point['Z'])

                    if point['Z'] < min_z:
                        min_z = point['Z']
                    if point['Z'] > max_z:
                        max_z = point['Z']

                avg_x = sum(X_vec) / len(X_vec)
                avg_y = sum(Y_vec) / len(Y_vec)
                avg_z = sum(Z_vec) / len(Z_vec)

                plot_handles[initial_frame_key]['axes'].set_xlim3d(avg_x - 10 * avg_x, avg_x + 10 * avg_x)
                plot_handles[initial_frame_key]['axes'].set_ylim3d(avg_y - 10 * avg_y, avg_y + 10 * avg_y)
                plot_handles[initial_frame_key]['axes'].set_zlim3d(1.1 * min_z, 1.1 * max_z)

                if args['has_ground_truth']:
                    for joint_key, joint in ground_truth[this_frame_key].items():
                        point = joint['pose']
                        X_vec.append(float(point['x']))
                        Y_vec.append(float(point['y']))
                        Z_vec.append(float(point['z']))

                plot_handles[initial_frame_key]['joint_handle']._offsets3d = (X_vec, Y_vec, Z_vec)
                draw3Dcoordinatesystem(plot_handles[initial_frame_key]['axes'], plot_handles[initial_frame_key]['coordinate_system'], 0, 0,
                                       min_z, 0.2, update=True)
                plt.draw()

                for link_name, link in links.items():
                    joint0 = poses3d[this_frame_key][link['parent']]
                    joint1 = poses3d[this_frame_key][link['child']]

                    X0 = joint0['X']
                    Y0 = joint0['Y']
                    Z0 = joint0['Z']
                    X1 = joint1['X']
                    Y1 = joint1['Y']
                    Z1 = joint1['Z']

                    plot_handles[initial_frame_key][link_name]['link_handle'][0].set_xdata([X0, X1])
                    plot_handles[initial_frame_key][link_name]['link_handle'][0].set_ydata([Y0, Y1])
                    plot_handles[initial_frame_key][link_name]['link_handle'][0].set_3d_properties(zs=[Z0, Z1])

                if args['has_ground_truth']:
                    for link_name, link in links_mpi_inf_3dhp.items():
                        joint0 = ground_truth[this_frame_key][link['parent']]['pose']
                        joint1 = ground_truth[this_frame_key][link['child']]['pose']
                        X0 = joint0['x']
                        Y0 = joint0['y']
                        Z0 = joint0['z']
                        X1 = joint1['x']
                        Y1 = joint1['y']
                        Z1 = joint1['z']

                        plot_handles[initial_frame_key]['ground_truth'][link_name]['link_handle_gt'][0].set_xdata([X0, X1])
                        plot_handles[initial_frame_key]['ground_truth'][link_name]['link_handle_gt'][0].set_ydata([Y0, Y1])
                        plot_handles[initial_frame_key]['ground_truth'][link_name]['link_handle_gt'][0].set_3d_properties(
                            zs=[Z0, Z1])

                plt.draw()

            if wm is not None:
                wm.waitForKey(time_to_wait=0.001, verbose=False, message='')
            else:
                cv2.waitKey(1)  # no window manager available; let OpenCV refresh the image windows

        if args['show_images'] or args['show_3d']:
            opt.setVisualizationFunction(visualization_function, always_visualize=True, niterations=10)
            opt.internal_visualization = False

        # ----------------------------------------------
        # Run the optimization for this frame
        # ----------------------------------------------
        opt.startOptimization(optimization_options={'x_scale': 'jac', 'ftol': 1e-5,
                                                    'xtol': 1e-5, 'gtol': 1e-5, 'diff_step': None})  # 'max_nfev': 1



if __name__ == "__main__":
    main()
