#!/usr/bin/env python3
import sys
# sys.path.append('../include/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from cv_bridge import CvBridge
from openpose import pyopenpose as op
from pose_estimator_2d import openpose_estimator
from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
from hpe.msg import skeleton, person2D, keypoint2D
import functools
import operator


class HPE():
    """Run OpenPose on one camera topic and publish the resulting skeleton.

    For every received frame the class publishes:

    * the frame with the first detected person's skeleton drawn on it
      (``sensor_msgs/Image`` on ``self.pub``), and
    * a ``person2D`` message with that person's keypoints
      (on ``self.publish_skeleton``).
    """

    # Minimum keypoint confidence for a joint or limb to be drawn.
    KP_THRESH = 0.4

    def __init__(self, topic, opWrapper):
        """Create the publishers and subscribe to ``topic``.

        Args:
            topic: Name of the image topic. If it contains the substring
                ``"compressed"`` it is subscribed as ``CompressedImage``,
                otherwise as a raw ``Image``.
            opWrapper: A configured and started OpenPose ``WrapperPython``.
        """
        self.image_raw = None
        # Kept for compatibility; this class never assigns it.
        self.image_skeleton = None
        self.br = CvBridge()
        self.opWrapper = opWrapper
        # Built once here instead of once per frame.
        self._skeleton = openpose_skeleton.OpenPoseSkeleton()
        self._compressed = "compressed" in topic

        # NOTE(review): for an absolute topic ("/cam/...") the first split
        # element is the empty string, so the keypoint topic is "_skeleton"
        # for every camera. Name kept as-is so existing subscribers still work.
        topic_parts = str(topic).split('/')
        self.publish_skeleton = rospy.Publisher(
            topic_parts[0] + "_skeleton", person2D, queue_size=2)

        publishing_topic = self._skeleton_image_topic(topic)
        self.pub = rospy.Publisher(publishing_topic, Image, queue_size=2)

        self.person2d_msg = person2D()
        self.person2d_msg.header.frame_id = publishing_topic
        self.person2d_msg.header.stamp = rospy.Time.now()

        # Subscribe last: the callback may fire as soon as the subscriber
        # exists, and it needs both publishers and the message to be ready.
        msg_type = CompressedImage if self._compressed else Image
        rospy.Subscriber(topic, msg_type, self.callback)

    @staticmethod
    def _skeleton_image_topic(topic):
        """Return the name of the overlay-image topic derived from ``topic``.

        A topic containing ``"compressed"`` has its last ``len("compressed")``
        characters replaced by ``"skeleton"``; any other topic simply gets
        ``"skeleton"`` appended.
        """
        if "compressed" in topic:
            return topic[:-len("compressed")] + 'skeleton'
        return topic + 'skeleton'

    def callback(self, msg):
        """Decode one incoming frame, run pose estimation and publish results.

        Args:
            msg: ``CompressedImage`` or ``Image``, matching the subscription
                made in ``__init__``.
        """
        if self._compressed:
            # frombuffer replaces the deprecated binary mode of np.fromstring.
            np_arr = np.frombuffer(msg.data, np.uint8)
            image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        else:
            # Raw images carry pixel data, not an encoded file, so they must
            # go through cv_bridge rather than cv2.imdecode.
            image = self.br.imgmsg_to_cv2(msg, desired_encoding='bgr8')

        if image is None:
            # cv2.imdecode returns None when the buffer cannot be decoded.
            rospy.logwarn("Could not decode image; frame skipped")
            return

        self.image_raw = image
        image_skeleton = self.hpe(image)

        # Refresh the stamp for every frame instead of reusing the one taken
        # at construction time.
        self.person2d_msg.header.stamp = rospy.Time.now()
        self.pub.publish(self.br.cv2_to_imgmsg(image_skeleton))
        self.publish_skeleton.publish(self.person2d_msg)

    def hpe(self, img):
        """Run OpenPose on ``img`` and update ``self.person2d_msg``.

        Only the first entry of ``poseKeypoints`` is used.

        Args:
            img: Image array handed to OpenPose as ``cvInputData``.

        Returns:
            ``img`` with the skeleton drawn on it, or ``img`` itself when no
            keypoints were returned.
        """
        datum = op.Datum()
        datum.cvInputData = img
        self.opWrapper.emplaceAndPop(op.VectorDatum([datum]))
        keypoints = datum.poseKeypoints

        # Indexing below needs a (people, joints, 3)-like array with at least
        # one person; treat None and any other shape as "nobody detected".
        # NOTE(review): some OpenPose builds may return an empty array rather
        # than None here -- confirm against the installed version.
        if keypoints is None or np.ndim(keypoints) != 3 or len(keypoints) == 0:
            self.person2d_msg.keypoint_list = []
            return img

        first_person = keypoints[0]

        # tolist() yields plain Python floats for the message fields.
        keypoints_msgs = []
        for x, y, score in first_person.tolist():
            kp_msg = keypoint2D()
            kp_msg.x = x
            kp_msg.y = y
            kp_msg.score = score
            keypoints_msgs.append(kp_msg)
        self.person2d_msg.keypoint_list = keypoints_msgs

        return self.keypoints_to_image(
            keypoints=first_person, img=img,
            skeleton=self._skeleton, kp_thresh=self.KP_THRESH)

    def keypoints_to_image(self, keypoints, img, skeleton, kp_thresh, alpha=0.7):
        """Draw a skeleton on a copy of ``img`` and blend it with the original.

        Args:
            keypoints: Array indexed as ``keypoints[joint, 0..2]`` giving
                x, y and score for each joint.
            img: Image to draw on; it is not modified.
            skeleton: Object providing ``keypoint_num``, ``root``,
                ``keypoint2index`` and ``children``.
            kp_thresh: Joints and limbs are drawn only when the scores
                involved are strictly greater than this value. ``None``
                disables the check and draws everything.
            alpha: Weight of the drawn layer in the final blend.

        Returns:
            The blended image.
        """
        # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
        cmap = plt.get_cmap('rainbow')
        colors = [
            (c[2] * 255, c[1] * 255, c[0] * 255)
            for c in (cmap(i) for i in np.linspace(0, 1, skeleton.keypoint_num))
        ]

        def confident(idx):
            # One shared test for parent and child, so a threshold of 0 is
            # treated as a real threshold and never compared against None.
            return kp_thresh is None or keypoints[idx, 2] > kp_thresh

        def position(idx):
            return int(keypoints[idx, 0]), int(keypoints[idx, 1])

        mask = img.copy()
        stack = [skeleton.root]

        # Depth-first walk over the skeleton tree starting at the root.
        while stack:
            parent = stack.pop()
            p_idx = skeleton.keypoint2index[parent]
            p_pos = position(p_idx)
            p_ok = confident(p_idx)

            if p_ok:
                cv2.circle(
                    mask, p_pos, radius=3,
                    color=colors[p_idx], thickness=-1, lineType=cv2.LINE_AA)

            for child in skeleton.children[parent]:
                # Skip joints this skeleton does not map to a keypoint index.
                if child not in skeleton.keypoint2index or skeleton.keypoint2index[child] < 0:
                    continue
                stack.append(child)
                c_idx = skeleton.keypoint2index[child]
                if p_ok and confident(c_idx):
                    cv2.line(
                        mask, p_pos, position(c_idx),
                        color=colors[c_idx], thickness=2, lineType=cv2.LINE_AA)

        output_img = cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)

        return output_img


def main():
    """Start the node and create one HPE instance per topic on the command line.

    The OpenPose model folder can be overridden with the private ROS
    parameter ``~model_folder``; the previous hard-coded path is the default.
    """
    rospy.init_node('hpe_skeleton')

    # get topics to do hpe
    # rospy.myargv removes ROS remapping arguments (name:=value) wherever they
    # appear. The original slice dropped the last two arguments
    # unconditionally -- presumably the remappings roslaunch appends -- which
    # discards real topics when the node is started any other way.
    topics = rospy.myargv(argv=sys.argv)[1:]
    if not topics:
        rospy.logwarn("No image topics given; nothing will be processed")

    # start python wrapper
    params = {
        'model_folder': rospy.get_param(
            '~model_folder', '/home/daniela/openpose/models/'),
        'render_pose': 0,
        'model_pose': 'BODY_25',
        'net_resolution': '320x176',
    }
    opWrapper = op.WrapperPython()
    opWrapper.configure(params)
    opWrapper.start()

    rospy.loginfo("Publishing images with human skeleton")
    rospy.loginfo("Topics: " + str(topics))

    # Hold references so the per-camera objects live as long as the node.
    cameras = [HPE(topic, opWrapper) for topic in topics]

    # All work happens in subscriber callbacks; block until shutdown.
    rospy.spin()


if __name__ == '__main__':
    main()
