#!/usr/bin/env python3
import sys
# sys.path.append('../include/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from cv_bridge import CvBridge
from openpose import pyopenpose as op
from pose_estimator_2d import openpose_estimator
from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
from hpe.msg import skeleton, keypoint2D, person2D
import functools
import operator


class HPE:
    """Run OpenPose on a ROS image topic and republish the results.

    Subscribes to ``topic`` (as ``CompressedImage`` when the topic name ends in
    ``compressed``, as ``Image`` otherwise) and runs OpenPose (BODY_25) on every
    received frame. ``start()`` then publishes at 15 Hz:

    * a ``person2D`` message holding the keypoints of the first detected person,
      on ``<first topic segment>_skeleton``;
    * the frame with the skeleton drawn on it, on ``<topic>/skeleton`` (for a
      ``.../compressed`` topic the ``compressed`` suffix becomes ``skeleton``).
    """

    def __init__(self, topic, *, model_folder='/home/daniela/openpose/models/',
                 net_resolution='320x176'):
        """Start OpenPose and create the ROS publishers and the subscriber.

        Args:
            topic: input image topic name.
            model_folder: directory holding the OpenPose models.
            net_resolution: OpenPose network input resolution, "WxH".
        """
        topic = str(topic)
        self.rate = rospy.Rate(15)
        self.image_raw = None
        self.image_skeleton = None
        self.br = CvBridge()
        # Only read when drawing, so build it once instead of once per frame.
        self.op_skel = openpose_skeleton.OpenPoseSkeleton()

        person_topic, image_topic = self._output_topics(topic)
        self.publish_skeleton = rospy.Publisher(person_topic, person2D, queue_size=2)
        self.pub = rospy.Publisher(image_topic, Image, queue_size=2)

        self.person2d_msg = person2D()
        # NOTE: frame_id holds the skeleton image topic name, not a TF frame.
        self.person2d_msg.header.frame_id = image_topic
        # Refreshed from each incoming image in callback().
        self.person2d_msg.header.stamp = rospy.Time.now()

        params = {'model_folder': model_folder, 'render_pose': 0, 'model_pose': 'BODY_25',
                  'net_resolution': net_resolution}
        self.opWrapper = op.WrapperPython()
        self.opWrapper.configure(params)
        self.opWrapper.start()

        # Subscribe last: the callback uses OpenPose and the message set up above.
        msg_type = CompressedImage if topic.endswith('compressed') else Image
        self.sub = rospy.Subscriber(topic, msg_type, self.callback)

    @staticmethod
    def _output_topics(topic):
        """Return ``(person2D topic, skeleton image topic)`` for input ``topic``.

        >>> HPE._output_topics("/camera/image_raw/compressed")
        ('camera_skeleton', '/camera/image_raw/skeleton')
        """
        topic = str(topic)
        # Strip the leading "/" so absolute names do not give an empty first segment.
        person_topic = topic.strip('/').split('/')[0] + '_skeleton'
        if topic.endswith('compressed'):
            image_topic = topic[:-len('compressed')] + 'skeleton'
        else:
            image_topic = topic.rstrip('/') + '/skeleton'
        return person_topic, image_topic

    def hpe(self, img):
        """Run OpenPose on ``img`` and return it with the skeleton drawn on.

        Stores the keypoints of the first detected person in
        ``self.person2d_msg.keypoint_list``. When nobody is detected the list is
        emptied and ``img`` is returned unchanged.
        """
        datum = op.Datum()
        datum.cvInputData = img
        self.opWrapper.emplaceAndPop(op.VectorDatum([datum]))
        keypoints = datum.poseKeypoints

        # Expected layout is (people, parts, 3). Treat None or an empty/0-d
        # result as "no detection" (np.ndim(None) == 0).
        if np.ndim(keypoints) != 3 or len(keypoints) == 0:
            self.person2d_msg.keypoint_list = []
            return img

        first_person = keypoints[0]
        self.person2d_msg.keypoint_list = [
            keypoint2D(x=x, y=y, score=score) for x, y, score in first_person.tolist()
        ]
        return self.vis_2d_keypoints(
            keypoints=first_person,
            img=img,
            skeleton=self.op_skel,
            kp_thresh=0.4,
        )

    def vis_2d_keypoints(self, keypoints, img, skeleton, kp_thresh, alpha=0.7, show_name=False):
        """Draw one person's joints and bones onto a copy of ``img``.

        Args:
            keypoints: rows of (x, y, score), indexed via ``skeleton.keypoint2index``.
            img: image to draw on; it is not modified.
            skeleton: object providing ``root``, ``children``, ``keypoint2index``
                and ``keypoint_num``.
            kp_thresh: a joint (or bone) is drawn only if its score (both end
                scores) exceeds this value; ``None`` draws everything.
            alpha: weight of the drawing when blended over ``img``.
            show_name: also write each drawn joint's name next to it.

        Returns:
            ``img`` blended with the drawing.
        """
        cmap = plt.get_cmap('rainbow')
        # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
        colors = [(c[2] * 255, c[1] * 255, c[0] * 255)
                  for c in map(cmap, np.linspace(0, 1, skeleton.keypoint_num))]

        mask = img.copy()
        stack = [skeleton.root]

        while stack:
            parent = stack.pop()
            p_idx = skeleton.keypoint2index[parent]
            p_pos = int(keypoints[p_idx, 0]), int(keypoints[p_idx, 1])
            p_score = keypoints[p_idx, 2] if kp_thresh is not None else None

            if kp_thresh is None or p_score > kp_thresh:
                cv2.circle(
                    mask, p_pos, radius=3,
                    color=colors[p_idx], thickness=-1, lineType=cv2.LINE_AA)
                if show_name:
                    cv2.putText(mask, parent, p_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

            for child in skeleton.children[parent]:
                # Skip children without a valid keypoint index.
                if child not in skeleton.keypoint2index or skeleton.keypoint2index[child] < 0:
                    continue
                stack.append(child)
                c_idx = skeleton.keypoint2index[child]
                c_pos = int(keypoints[c_idx, 0]), int(keypoints[c_idx, 1])
                # Same "is not None" test as for the parent: a threshold of 0
                # must still compare scores instead of comparing None.
                c_score = keypoints[c_idx, 2] if kp_thresh is not None else None
                if kp_thresh is None or (p_score > kp_thresh and c_score > kp_thresh):
                    cv2.line(
                        mask, p_pos, c_pos,
                        color=colors[c_idx], thickness=2, lineType=cv2.LINE_AA)

        return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)

    def callback(self, msg):
        """Decode an incoming image, run pose estimation and store the results.

        ``msg`` is a ``CompressedImage`` or an ``Image``, matching the type
        chosen in ``__init__``.
        """
        if isinstance(msg, CompressedImage):
            # np.frombuffer replaces the deprecated binary mode of np.fromstring.
            frame = cv2.imdecode(np.frombuffer(msg.data, np.uint8), cv2.IMREAD_COLOR)
        else:
            # Raw images hold unencoded pixels, which cv2.imdecode cannot read.
            frame = self.br.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        if frame is None:
            rospy.logwarn_throttle(5.0, "Could not decode incoming image; frame skipped")
            return

        self.image_raw = frame
        self.image_skeleton = self.hpe(frame)
        # Stamp results with the source image time so they can be matched to it.
        stamp = msg.header.stamp
        self.person2d_msg.header.stamp = rospy.Time.now() if stamp.is_zero() else stamp

    def start(self):
        """Publish the latest image and keypoints at the configured rate until shutdown."""
        rospy.loginfo("Publishing images with human skeleton")
        while not rospy.is_shutdown():
            skeleton_img = self.image_skeleton
            frame = skeleton_img if skeleton_img is not None else self.image_raw
            if frame is not None:
                # Frames are BGR (cv2.imdecode IMREAD_COLOR / imgmsg_to_cv2 'bgr8').
                img_msg = self.br.cv2_to_imgmsg(frame, encoding='bgr8')
                img_msg.header.stamp = self.person2d_msg.header.stamp
                self.pub.publish(img_msg)
            self.publish_skeleton.publish(self.person2d_msg)

            self.rate.sleep()


if __name__ == '__main__':
    # rospy.myargv drops the "name:=value" remapping arguments that roslaunch
    # appends, so args[1] is the user-supplied image topic.
    args = rospy.myargv(argv=sys.argv)
    if len(args) < 2:
        sys.exit(f"usage: {args[0]} <image_topic>")
    rospy.init_node('hpe_skeleton')
    node = HPE(args[1])
    try:
        node.start()
    except rospy.ROSInterruptException:
        # Raised by Rate.sleep() when the node shuts down mid-sleep.
        pass
