#!/usr/bin/env python3
import sys
from pathlib import Path
import os

from pyglet.media.drivers.pulse.lib_pulseaudio import PA_CHANNEL_POSITION_AUX29

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
                + '/scripts/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from visualization_msgs.msg import Marker
from geometry_msgs.msg import Point, TransformStamped
from tf2_msgs.msg import TFMessage
from cv_bridge import CvBridge

import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from hpe.msg import skeleton, person2D, keypoint2D, person3D, keypoint3D
import functools
import operator
from functools import partial

from utils.links import joints_mediapipe, links_mediapipe,  links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from utils.transforms import getTransform, getChain, getAggregateTransform
from urdf_parser_py.urdf import URDF
from atom_core.naming import generateKey
from tf.transformations import quaternion_from_matrix, quaternion_from_euler
from optimization.objective_function_ros import objectiveFunction
from utils.load_cam_params import getCamParamsMPI

if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager


class HPE():
    """Per-camera 2D pose subscriber and holder of 3D-skeleton publishing state.

    For every input topic a ``person2D`` subscriber is created and a set of
    per-camera containers (keypoint buffer, frame dictionary, frame counter,
    "received" flag) is initialised. Two publishers are also set up:
    ``/3D_skeleton`` (``person3D``) and ``/3D_skeleton_vis`` (``Marker``).

    Per-camera containers are keyed by the first path segment of the topic
    name (see ``_camera_key``), e.g. ``camera_1/person2D`` -> ``camera_1``.
    """

    # Number of rows allocated per camera in ``keypoint_list``.
    # Presumably the 33 MediaPipe pose landmarks -- TODO confirm against the
    # 2D detector that publishes ``person2D``.
    N_KEYPOINTS = 33

    def __init__(self, topics):
        """Create publishers, per-camera state and one subscriber per topic.

        Args:
            topics: iterable of ROS topic names carrying ``person2D`` messages.
        """
        # NOTE(review): these three flags are not read anywhere in this class;
        # their meaning is not visible from here -- confirm before relying on them.
        self.sll = True
        self.sff = True
        self.mpi = True

        # Publishers - 3D skeleton keypoint_list and markers to draw skeleton
        self.publish_skeleton = rospy.Publisher("/3D_skeleton", person3D, queue_size=2)
        self.person3d_msg = person3D()
        self.marker_pub = rospy.Publisher("/3D_skeleton_vis", Marker, queue_size=10)
        self.person3d_msg.header.frame_id = "world"
        self.person3d_msg.header.stamp = rospy.Time.now()
        self.selected_camera_topic = ''

        self.lock_frame = False
        self.current_frame = 0

        self.keypoint_list = {}
        self.subscribers = {}

        self.received = {}
        self.cameras = {}
        self.topics = []

        for topic in topics:
            camera = self._camera_key(topic)
            self.topics.append(camera)

            self.received[camera] = False
            self.cameras[camera] = {'frames': {}}

            # The first camera seen becomes the selected one.
            if not self.selected_camera_topic:
                self.selected_camera_topic = camera

            # One row per keypoint, three values per row.
            self.keypoint_list[camera] = np.zeros((self.N_KEYPOINTS, 3))

        self.topic_frame_idxs = {camera: 0 for camera in self.topics}

        # Subscribers are created only after all per-camera state exists,
        # because the callback indexes ``topic_frame_idxs`` by camera key.
        for topic in topics:
            self.subscribers[topic] = rospy.Subscriber(
                topic, person2D, self.callback, self._camera_key(topic))

    @staticmethod
    def _camera_key(topic):
        """Return the first path segment of *topic*, ignoring leading slashes.

        ``camera_1/person2D`` and ``/camera_1/person2D`` both map to
        ``camera_1``. Without stripping the leading slash, an absolute topic
        name would yield an empty key and all cameras would share one entry.
        """
        return str(topic).lstrip('/').split('/', 1)[0]

    def callback(self, msg, topic):
        """Report the current frame index of a camera and advance it.

        Args:
            msg: incoming ``person2D`` message (currently unused).
            topic: camera key passed as the subscriber's callback argument.
        """
        # TODO: keypoint parsing is currently disabled. The disabled logic
        # copied each keypoint's (x, y, score) into ``keypoint_list[topic]``,
        # built a per-joint dict with keys 'x', 'y', 'confidence', 'valid',
        # 'x_proj', 'y_proj' (a joint being valid when x and y are non-zero and
        # confidence >= 0.25), and stored it under
        # ``cameras[topic]['frames'][str(frame_idx)]`` together with
        # 'valid_frame' (True when more than 8 joints were valid; False when
        # the message had no keypoints).
        print(str(topic) + " is in frame " + str(self.topic_frame_idxs[topic]) + ".")
        self.topic_frame_idxs[topic] += 1


if __name__ == '__main__':
    # Topics to run HPE on. rospy.myargv() returns argv with the ROS remapping
    # arguments (anything of the form `name:=value`, such as the `__name:=...`
    # and `__log:=...` pair appended by roslaunch) removed. Unlike slicing off
    # a fixed number of trailing arguments, this also gives the full topic
    # list when the node is started without roslaunch.
    topics = rospy.myargv(argv=sys.argv)[1:]

    # start python wrapper
    rospy.init_node('hpe_skeleton')
    rate = rospy.Rate(3)
    rospy.loginfo("Publishing 3D human skeleton")
    rospy.loginfo("Topics: " + str(topics))
    node = HPE(topics)
    # cameras.append(node)
    try:
        # Keep the process alive so the subscriber callbacks keep running.
        while not rospy.is_shutdown():
            # node.hpe_3d()
            rate.sleep()
    except rospy.ROSInterruptException:
        # Raised by rate.sleep() when the node is shut down mid-sleep;
        # this is the normal exit path, so end quietly.
        pass
