#!/usr/bin/env python3
import sys
from pathlib import Path
import os

from pyglet.media.drivers.pulse.lib_pulseaudio import PA_CHANNEL_POSITION_AUX29

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
                + '/scripts/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from visualization_msgs.msg import Marker
from geometry_msgs.msg import Point, TransformStamped
from tf2_msgs.msg import TFMessage
from cv_bridge import CvBridge

import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from hpe.msg import skeleton, person2D, keypoint2D, person3D, keypoint3D
import functools
import operator
from functools import partial

from utils.links import joints_mediapipe, links_mediapipe,  links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from utils.transforms import getTransform, getChain, getAggregateTransform
from urdf_parser_py.urdf import URDF
from atom_core.naming import generateKey
from tf.transformations import quaternion_from_matrix, quaternion_from_euler
from optimization.objective_function_ros import objectiveFunction
from utils.load_cam_params import getCamParamsMPI
from utils.draw import drawSquare2D, drawCross2D, drawCircle, drawDiagonalCross2D, draw3Dcoordinatesystem


if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager


class HPE():
    """Fuse per-camera 2D skeletons into a 3D skeleton and publish it to RViz.

    One ``person2D`` subscriber is created per input topic. Every received
    message is stored as a numbered frame under
    ``self.cameras[<camera>]['frames'][str(idx)]``. :meth:`hpe_3d`, driven by
    the main loop, builds an ``Optimizer`` problem for ``self.current_frame``.
    The problem has one projection residual per (camera, valid joint), plus
    optional frame-to-frame and link-length residuals. The refined skeleton is
    published as a LINE_LIST marker on ``/3D_skeleton_vis``.
    """

    # A camera frame is valid when MORE than this many keypoints are valid.
    _MIN_VALID_KEYPOINTS = 15
    # Keypoints with a score below this are treated as invalid.
    _MIN_KEYPOINT_SCORE = 0.25
    # The optimisation runs only when MORE than this many cameras are valid.
    _MIN_VALID_CAMERAS = 2
    # Rows of each camera's keypoint_list array.
    _N_KEYPOINTS = 33

    # (parent, child) joint pairs drawn for the MPI-INF-3DHP joint set.
    _MPI_SEGMENTS = (
        ('Head', 'HeadTop'),
        ('Head', 'Neck'),
        ('Neck', 'Spine4'),
        ('Spine4', 'Spine3'),
        ('Spine3', 'Spine2'),
        ('Spine2', 'Pelvis'),
        ('RShoulder', 'RElbow'),
        ('RElbow', 'RWrist'),
        ('RWrist', 'RHand'),
        ('Spine4', 'RShoulder'),
        ('LShoulder', 'LElbow'),
        ('LElbow', 'LWrist'),
        ('LWrist', 'LHand'),
        ('Spine4', 'LShoulder'),
        ('Pelvis', 'LHip'),
        ('LHip', 'LKnee'),
        ('LKnee', 'LAnkle'),
        ('LAnkle', 'LFoot'),
        ('Pelvis', 'RHip'),
        ('RHip', 'RKnee'),
        ('RKnee', 'RAnkle'),
        ('RAnkle', 'RFoot'),
    )

    # (parent, child) joint pairs drawn for the MediaPipe joint set.
    _MEDIAPIPE_SEGMENTS = (
        ('RShoulder', 'RElbow'),
        ('RElbow', 'RWrist'),
        ('LShoulder', 'LElbow'),
        ('LElbow', 'LWrist'),
        ('RShoulder', 'LShoulder'),
        ('RHip', 'LHip'),
        ('RShoulder', 'RHip'),
        ('LHip', 'LShoulder'),
        ('RHip', 'RKnee'),
        ('RKnee', 'RAnkle'),
        ('LHip', 'LKnee'),
        ('LKnee', 'LAnkle'),
        ('RAnkle', 'RHeel'),
        ('RAnkle', 'RBigToe'),
        ('LAnkle', 'LHeel'),
        ('LAnkle', 'LBigToe'),
    )

    def __init__(self, topics):
        """Create publishers/subscribers and load camera calibration.

        :param topics: ``person2D`` topic names, one per camera. The first path
            component (ignoring a leading ``/``) is the camera name. It is used
            to look up ``<camera>_link`` in the URDF and to wait for
            ``/<camera>/rgb/camera_info``.

        Blocks until one CameraInfo message has been received per camera.
        """
        # Optimisation switches: skip link-length / frame-to-frame residuals.
        self.sll = True
        self.sff = True
        # Use MPI-INF-3DHP calibration instead of the URDF for extrinsics.
        self.mpi = False

        # Publishers: 3D skeleton message and RViz marker.
        self.publish_skeleton = rospy.Publisher("/3D_skeleton", person3D, queue_size=2)
        self.person3d_msg = person3D()
        self.marker_pub = rospy.Publisher("/3D_skeleton_vis", Marker, queue_size=10)
        self.person3d_msg.header.frame_id = "world"
        self.person3d_msg.header.stamp = rospy.Time.now()
        self.selected_camera_topic = ''

        self.lock_frame = False
        self.current_frame = 0

        self.poses3d = {}
        self.keypoint_list = {}
        self.subscribers = {}
        self.received = {}
        self.cameras = {}
        self.topics = []
        # Per-camera reference images used by show_2d_skeleton. Nothing in this
        # class fills it, so it must be populated before that method is called.
        self.images_0 = {}

        for topic in topics:
            camera = self._camera_name(topic)
            print(camera)
            self.topics.append(camera)
            self.received[camera] = False
            self.cameras[camera] = {'frames': {}}
            if self.selected_camera_topic == '':
                self.selected_camera_topic = camera
            self.keypoint_list[camera] = np.zeros((self._N_KEYPOINTS, 3))

        # Must exist before the subscribers are created: callbacks index it.
        self.topic_frame_idxs = {camera: 0 for camera in self.topics}

        for topic in topics:
            self.subscribers[topic] = rospy.Subscriber(topic, person2D, self.callback,
                                                       self._camera_name(topic))

        self.transform_tree_dict = self.get_transform_tree_dict(
            os.path.dirname(os.path.abspath(__file__)) + '/../xacro/larcc.urdf')

        self.get_camera_extrinsics()
        self.get_camera_intrinsics()

        self.args = {
            'has_ground_truth': False,
            'dataset_name': "lar",
            '2d_detector': "mediapipe",
            'skip_frame_to_frame_residuals': self.sff,
            'skip_link_length_residuals': self.sll,
            'debug': False
        }

        # Frame index of the first hpe_3d call; -1 until then.
        self.first_frame = -1

    @staticmethod
    def _camera_name(topic):
        """Return the camera name of a topic: its first non-empty path component."""
        return str(topic).strip('/').split('/')[0]

    @staticmethod
    def _seed_pose(previous_pose, joints_to_use):
        """Return an independent copy of ``previous_pose`` restricted to ``joints_to_use``.

        The optimiser setter mutates the per-joint dicts in place, so the new
        frame must not share them with the previous frame.
        """
        return {joint_key: dict(previous_pose[joint_key]) for joint_key in joints_to_use}

    def callback(self, msg, topic):
        """Store one ``person2D`` message as the next frame of camera ``topic``.

        Each keypoint becomes ``{'x', 'y', 'confidence', 'valid', 'x_proj', 'y_proj'}``.
        A keypoint is valid when neither x nor y is 0 and its score is at least
        ``_MIN_KEYPOINT_SCORE``. The frame is valid when more than
        ``_MIN_VALID_KEYPOINTS`` keypoints are valid.

        :param msg: ``person2D`` message (keypoints expose ``x``, ``y``, ``score``).
        :param topic: camera name (the subscriber's callback argument).
        """
        print("Received message from " + topic)
        keypoints = msg.keypoints or []
        if not keypoints:
            print("Camera " + str(topic) + " has no keypoints.")

        keypoint_array = self.keypoint_list[topic]
        joints = {}
        n_valid = 0
        for idx, kpt in enumerate(keypoints):
            if idx < len(keypoint_array):  # fixed-size array; ignore extra keypoints
                keypoint_array[idx] = (kpt.x, kpt.y, kpt.score)

            valid = kpt.x != 0 and kpt.y != 0 and kpt.score >= self._MIN_KEYPOINT_SCORE
            joints[idx] = {'x': kpt.x, 'y': kpt.y, 'confidence': kpt.score, 'valid': valid,
                           'x_proj': 0.0, 'y_proj': 0.0}
            if valid:
                n_valid += 1

        self.cameras[topic]['frames'][str(self.topic_frame_idxs[topic])] = {
            'joints': joints,
            'valid_frame': n_valid > self._MIN_VALID_KEYPOINTS,
        }
        self.topic_frame_idxs[topic] += 1

    def show_2d_skeleton(self, keypoints_2d, topic):
        """Draw a 2D skeleton over ``self.images_0[topic]`` and show it with OpenCV.

        :param keypoints_2d: a frame entry from ``self.cameras[topic]['frames']``.
        :param topic: camera name; also used as the window title.

        Drawing is done on a copy, so the cached image is left untouched.
        Keypoint coordinates are presumably pixels (they are cast to int and
        drawn directly); confirm against the 2D detector.
        """
        image = self.images_0[topic].copy()
        joints = keypoints_2d['joints']

        for joint in joints.values():
            x, y = int(joint['x']), int(joint['y'])
            # Square size grows linearly from 5 to 20 with the confidence.
            square_size = 5 + (20 - 5) * joint['confidence']
            drawSquare2D(image, x, y, square_size, color=(0, 0, 0), thickness=3)

        for link in links_mediapipe.values():
            joint0 = joints.get(joints_mediapipe[link['parent']]['idx'])
            joint1 = joints.get(joints_mediapipe[link['child']]['idx'])
            if joint0 is None or joint1 is None or not joint0['valid'] or not joint1['valid']:
                continue
            cv2.line(image, (int(joint0['x']), int(joint0['y'])),
                     (int(joint1['x']), int(joint1['y'])), (128, 128, 128), 3)

        cv2.imshow(topic, image)

    def get_transform_tree_dict(self, file):
        """Parse a URDF file into a dict of joint transforms.

        :param file: path to the URDF file.
        :return: dict keyed by ``generateKey(parent, child)``. Each value holds
            ``'parent'``, ``'child'``, ``'trans'`` (origin xyz) and ``'quat'``
            (origin rpy converted with ``quaternion_from_euler``, axes
            ``'sxyz'``). A joint without ``<origin>`` gets the identity
            transform, as in the URDF spec.
        """
        xml_robot = URDF.from_xml_file(file)
        tree = {}

        for joint in xml_robot.joints:
            origin = joint.origin
            xyz = origin.xyz if origin is not None and origin.xyz is not None else [0.0, 0.0, 0.0]
            rpy = origin.rpy if origin is not None and origin.rpy is not None else [0.0, 0.0, 0.0]
            tree[generateKey(joint.parent, joint.child)] = {
                'child': joint.child,
                'parent': joint.parent,
                'trans': xyz,
                'quat': list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz')),
            }

        return tree

    def get_camera_extrinsics(self):
        """Set ``self.cameras[<camera>]['extrinsics']`` for every camera.

        With ``self.mpi`` the ``'T'`` matrices of MPI-INF-3DHP S1/Seq1 are
        used; otherwise the ``world`` -> ``<camera>_link`` transform from the
        URDF tree.
        """
        if self.mpi:
            cam_params = getCamParamsMPI('S1', 'Seq1', ['camera_0', 'camera_4', 'camera_5', 'camera_8'])
            for camera_name, camera in self.cameras.items():
                camera['extrinsics'] = cam_params[camera_name]['T']
        else:
            for camera_name, camera in self.cameras.items():
                camera['extrinsics'] = getTransform('world', camera_name + '_link',
                                                    self.transform_tree_dict)

    def get_camera_intrinsics(self):
        """Fill intrinsics, distortion and image size from ``/<camera>/rgb/camera_info``.

        Blocks (no timeout) until one CameraInfo arrives per camera. ``K`` is
        reshaped to 3x3. ``D`` becomes a column vector of whatever length the
        distortion model provides (5 for plumb_bob, 8 for rational_polynomial).
        """
        for camera_name, camera in self.cameras.items():
            msg = rospy.wait_for_message('/' + camera_name + '/rgb/camera_info', CameraInfo, timeout=None)

            camera['intrinsics'] = np.reshape(msg.K, (3, 3))
            camera['distortion'] = np.reshape(msg.D, (-1, 1))
            camera['height'] = msg.height
            camera['width'] = msg.width

    def _joints_to_use(self):
        """Return the distinct joints appearing in any MediaPipe link (empty for other detectors)."""
        if self.args['2d_detector'] != "mediapipe":
            return []
        names = set()
        for link in links_mediapipe.values():
            names.add(link['parent'])
            names.add(link['child'])
        return list(names)

    def hpe_3d(self):
        """Estimate the 3D pose of ``self.current_frame``, publish it, and advance the frame.

        The pose is seeded at the origin on the first call and with a copy of
        the previous frame's estimate afterwards. It is refined and published
        only when more than ``_MIN_VALID_CAMERAS`` cameras hold a valid
        detection for this frame. The frame counter always advances.
        """
        if self.first_frame == -1:
            self.first_frame = self.current_frame

        frame_idx = self.current_frame
        frame_str = str(frame_idx)  # camera frames are keyed by str, poses3d by int
        joints_to_use = self._joints_to_use()

        if frame_idx == self.first_frame:
            self.poses3d[frame_idx] = {joint_key: {'X': 0.0, 'Y': 0.0, 'Z': 0.0}
                                       for joint_key in joints_to_use}
        else:
            # Copy, do not alias: the optimiser setter writes into these dicts.
            self.poses3d[frame_idx] = self._seed_pose(self.poses3d[frame_idx - 1], joints_to_use)

        n_valid_frames = 0
        for camera_key, camera in self.cameras.items():
            camera_frame = camera['frames'].get(frame_str)
            if camera_frame is not None and camera_frame['valid_frame']:
                n_valid_frames += 1
                rospy.logdebug("%s: %s", camera_key, camera_frame['joints'])

        if n_valid_frames > self._MIN_VALID_CAMERAS:
            print("OPTIMIZING WITH " + str(n_valid_frames) + " VALID CAMERAS FOR CURRENT FRAME " + str(frame_idx))
            self._optimize_frame(frame_idx, joints_to_use)
            self.draw_3D_skeleton(self.poses3d[frame_idx])

        self.current_frame += 1

    def _optimize_frame(self, frame_idx, joints_to_use):
        """Refine ``self.poses3d[frame_idx]`` in place with ``Optimizer`` and ``objectiveFunction``."""
        frame_str = str(frame_idx)

        opt = Optimizer()
        opt.addDataModel('args', self.args)
        opt.addDataModel('cameras', self.cameras)
        opt.addDataModel('poses3d', self.poses3d)
        opt.addDataModel('joints_to_use', joints_to_use)
        opt.addDataModel('errors', {})
        opt.addDataModel('frames_to_use', [frame_idx])

        def getJointXYZ(data, frame_key, joint_key):
            d = data[frame_key][joint_key]
            return [d['X'], d['Y'], d['Z']]

        def setJointXYZ(data, values, frame_key, joint_key):
            d = data[frame_key][joint_key]
            d['X'], d['Y'], d['Z'] = values[0], values[1], values[2]

        # One 3-vector parameter group per joint: frame_<f>_joint_<j>_{X,Y,Z}.
        for joint_key in joints_to_use:
            opt.pushParamVector('frame_' + frame_str + '_joint_' + str(joint_key), data_key='poses3d',
                                getter=partial(getJointXYZ, frame_key=frame_idx, joint_key=joint_key),
                                setter=partial(setJointXYZ, frame_key=frame_idx, joint_key=joint_key),
                                suffix=['_X', '_Y', '_Z'])

        opt.setObjectiveFunction(objectiveFunction)

        # Projection residuals: one per camera and valid joint of this frame.
        for camera_key, camera in self.cameras.items():
            camera_frame = camera['frames'].get(frame_str)
            if camera_frame is None:  # no message from this camera for this frame
                continue
            joints = camera_frame['joints']
            for joint_key in joints_to_use:
                joint = joints.get(joints_mediapipe[joint_key]['idx'])
                if joint is None or not joint['valid']:
                    continue
                # Trailing '_' keeps e.g. 'LEye' from matching 'LEyeInner'.
                params = opt.getParamsContainingPattern(
                    pattern='frame_' + frame_str + '_joint_' + str(joint_key) + '_')
                residual_key = ('projection_sensor_' + camera_key + '_frame_' + frame_str
                                + '_joint_' + str(joint_key))
                opt.pushResidual(name=residual_key, params=params)

        # Frame distance residuals.
        if not self.sff:
            frame_keys = list(self.poses3d.keys())
            for initial_frame_key, final_frame_key in zip(frame_keys[:-1], frame_keys[1:]):
                if int(final_frame_key) - int(initial_frame_key) != 1:  # frames are not consecutive
                    continue
                initial_str, final_str = str(initial_frame_key), str(final_frame_key)
                for joint_key in self.poses3d[initial_frame_key]:
                    params = opt.getParamsContainingPattern(
                        pattern='frame_' + initial_str + '_joint_' + str(joint_key) + '_')
                    params.extend(opt.getParamsContainingPattern(
                        pattern='frame_' + final_str + '_joint_' + str(joint_key) + '_'))
                    opt.pushResidual(name='consecutive_frame_' + initial_str + '_frame_' + final_str
                                     + '_joint_' + str(joint_key), params=params)

        # Link length residuals.
        if not self.sll:
            for link in links_mediapipe.values():
                print('Link ' + link['parent'] + '-' + link['child'])
                for frame_key in self.poses3d:
                    pose_str = str(frame_key)
                    params = opt.getParamsContainingPattern(
                        pattern='frame_' + pose_str + '_joint_' + link['parent'] + '_')
                    params.extend(opt.getParamsContainingPattern(
                        pattern='frame_' + pose_str + '_joint_' + link['child'] + '_'))
                    residual_key = ('length_frame_' + pose_str + '_joint_' + link['parent']
                                    + '_joint_' + link['child'])
                    opt.pushResidual(name=residual_key, params=params)

        opt.computeSparseMatrix()
        opt.printSparseMatrix()
        opt.startOptimization(optimization_options={'x_scale': 'jac', 'ftol': 1e-7,
                                                    'xtol': 1e-7, 'gtol': 1e-7,
                                                    'diff_step': None})

    def draw_3d_skeleton_debug(self, keypoints3d):
        """Plot a 3D skeleton with matplotlib and block until a key/button press.

        :param keypoints3d: mapping joint name -> ``{'X', 'Y', 'Z'}``.

        The axis limits (+/-1000 in x/y, 0..2000 in z) suggest millimetre
        input. TODO: confirm against the optimiser's units. The figure is
        closed afterwards so repeated calls do not accumulate figures.
        """
        fig = plt.figure()
        fig.suptitle('Frame #' + str(self.current_frame), fontsize=14)
        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlim3d(-1000, 1000)
        ax.set_ylim3d(-1000, 1000)
        ax.set_zlim3d(0, 2000)
        ax.set_xlabel('x', fontsize=20)
        ax.set_ylabel('y', fontsize=20)
        ax.set_zlabel('z', fontsize=20)
        plt.setp(ax.get_xticklabels(), visible=False)
        plt.setp(ax.get_yticklabels(), visible=False)
        plt.setp(ax.get_zticklabels(), visible=False)

        for link in links_mediapipe.values():
            joint0 = keypoints3d[link['parent']]
            joint1 = keypoints3d[link['child']]
            ax.plot([joint0['X'], joint1['X']],
                    [joint0['Y'], joint1['Y']],
                    [joint0['Z'], joint1['Z']],
                    c=(0.5, 0.5, 0.5))

        points = list(keypoints3d.values())
        ax.scatter([p['X'] for p in points], [p['Y'] for p in points], [p['Z'] for p in points],
                   c=[(0, 0, 0)] * len(points))

        plt.waitforbuttonpress()
        plt.close(fig)

    @staticmethod
    def _mpi_point(joint):
        """Convert an MPI joint ``{'X','Y','Z'}`` to a ``Point``, dividing each axis by 1000."""
        point = Point()
        point.x = joint['X'] / 1000
        point.y = joint['Y'] / 1000
        point.z = joint['Z'] / 1000
        return point

    @staticmethod
    def _mediapipe_point(joint):
        """Convert a MediaPipe joint to a ``Point`` with axes remapped x<-Z, y<-X, z<-Y."""
        point = Point()
        point.x = joint['Z']
        point.y = joint['X']
        point.z = joint['Y']
        return point

    def draw_3D_skeleton(self, keypoints3d):
        """Publish the skeleton as a green LINE_LIST marker in frame ``world``.

        :param keypoints3d: mapping joint name -> ``{'X', 'Y', 'Z'}``. It must
            contain every joint in the segment table of the active detector.

        Logs a warning and publishes nothing for an unknown ``2d_detector``.
        """
        detector = self.args['2d_detector']
        if detector == "mpi":
            segments, to_point = self._MPI_SEGMENTS, self._mpi_point
        elif detector == "mediapipe":
            segments, to_point = self._MEDIAPIPE_SEGMENTS, self._mediapipe_point
        else:
            rospy.logwarn("draw_3D_skeleton: unsupported 2d_detector '%s'", detector)
            return

        marker = Marker()
        marker.type = marker.LINE_LIST
        marker.action = marker.ADD
        marker.header.frame_id = 'world'
        marker.scale.x = 0.03  # line width

        marker.color.a = 1.0
        marker.color.r = 0.0
        marker.color.g = 1.0
        marker.color.b = 0.0

        # Identity pose: points are already expressed in 'world'.
        marker.pose.orientation.x = 0.0
        marker.pose.orientation.y = 0.0
        marker.pose.orientation.z = 0.0
        marker.pose.orientation.w = 1.0
        marker.pose.position.x = 0.0
        marker.pose.position.y = 0.0
        marker.pose.position.z = 0.0

        # LINE_LIST draws one segment per consecutive pair of points.
        points = {}
        marker.points = []
        for parent, child in segments:
            for name in (parent, child):
                if name not in points:
                    points[name] = to_point(keypoints3d[name])
                marker.points.append(points[name])

        self.marker_pub.publish(marker)


if __name__ == '__main__':
    # Topics to fuse come from the command line. rospy.myargv strips ROS
    # remapping arguments (e.g. the ``__name:=`` / ``__log:=`` that roslaunch
    # appends), so this works under both roslaunch and rosrun. The previous
    # ``sys.argv[1:-2]`` silently dropped the last two topics under rosrun.
    topics = rospy.myargv(argv=sys.argv)[1:]

    rospy.init_node('hpe_skeleton')
    rate = rospy.Rate(3)  # run the 3D estimation at 3 Hz
    rospy.loginfo("Publishing 3D human skeleton")
    rospy.loginfo("Topics: " + str(topics))
    node = HPE(topics)
    try:
        while not rospy.is_shutdown():
            node.hpe_3d()
            rate.sleep()
    except rospy.ROSInterruptException:
        # rate.sleep() raises this when the node is shut down mid-sleep.
        pass
