#!/usr/bin/env python3
import sys
# sys.path.append('../include/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from cv_bridge import CvBridge
from openpose import pyopenpose as op
# from pose_estimator_2d import openpose_estimator
# from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
# from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
from hpe.msg import skeleton, person2D, keypoint2D
import functools
import operator

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2



class HPE():
    """ROS helper that runs MediaPipe pose estimation on an image topic.

    Subscribes to *topic* (raw sensor_msgs/Image or CompressedImage),
    detects a human pose in every frame, republishes the annotated image
    on '<topic>skeleton' and the 2D keypoints as a person2D message.
    """

    def __init__(self, topic):
        """Wire up subscriber/publishers for *topic* and load the model.

        topic -- image topic name; if it contains 'compressed' the incoming
                 messages are decoded as sensor_msgs/CompressedImage.
        """
        self.image_raw = None
        self.image_skeleton = None
        self.br = CvBridge()

        # NOTE(review): ROS topics usually start with '/', so split('/')[0]
        # is '' and this publishes on the global '/skeleton' topic for every
        # camera — confirm that sharing one topic is intended.
        x = str(topic).split('/')
        self.publish_skeleton = rospy.Publisher(x[0] + "/skeleton", person2D, queue_size=2)
        self.person2d_msg = person2D()

        # Remember the transport so callback() can decode accordingly.
        self.is_compressed = "compressed" in topic
        if self.is_compressed:
            publishing_topic = topic[:-len("compressed")] + 'skeleton'
            rospy.Subscriber(topic, CompressedImage, self.callback)
        else:
            publishing_topic = topic + 'skeleton'
            rospy.Subscriber(topic, Image, self.callback)
        self.pub = rospy.Publisher(publishing_topic, Image, queue_size=2)
        self.person2d_msg.header.frame_id = publishing_topic
        self.person2d_msg.header.stamp = rospy.Time.now()

        self.model_path = os.path.dirname(os.path.abspath(__file__)) + '/../models/mediapipe/pose_landmarker_heavy.task'

        BaseOptions = mp.tasks.BaseOptions
        PoseLandmarker = mp.tasks.vision.PoseLandmarker
        PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
        VisionRunningMode = mp.tasks.vision.RunningMode

        # Read the model with a context manager so the file handle is closed
        # (the original 'open(...).read()' leaked the handle).
        with open(self.model_path, "rb") as model_file:
            model_buffer = model_file.read()

        options = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_buffer=model_buffer),
            running_mode=VisionRunningMode.IMAGE)

        self.landmarker = PoseLandmarker.create_from_options(options)

    def callback(self, msg):
        """Decode an incoming image message, run pose estimation, publish.

        msg -- CompressedImage or Image, depending on the subscribed topic.
        """
        if self.is_compressed:
            # np.frombuffer replaces the deprecated/removed np.fromstring.
            np_arr = np.frombuffer(msg.data, np.uint8)
            self.image_raw = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        else:
            # Raw sensor_msgs/Image payloads are not JPEG/PNG streams, so
            # cv2.imdecode cannot be used; convert via cv_bridge instead.
            self.image_raw = self.br.imgmsg_to_cv2(msg)
        image_skeleton = self.hpe(self.image_raw)
        # Refresh the stamp per frame (it was previously set only once at init).
        self.person2d_msg.header.stamp = rospy.Time.now()
        self.pub.publish(self.br.cv2_to_imgmsg(image_skeleton))
        self.publish_skeleton.publish(self.person2d_msg)

    def hpe(self, img):
        """Detect a pose in *img* and fill self.person2d_msg.keypoints.

        img -- HxWx3 uint8 array.
        Returns the annotated image (the input image when nothing was found).
        """
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

        pose_landmarker_result = self.landmarker.detect(mp_image)
        annotated_image, keypoints_mp = self.draw_landmarks_on_image(
            mp_image.numpy_view(), pose_landmarker_result)

        height, width, _ = annotated_image.shape

        # Convert normalized landmarks to pixel coordinates; landmarks that
        # fall outside the image become zeroed keypoints with score 0.
        keypoints_msgs = []
        for pt in keypoints_mp:
            msg = keypoint2D()
            pixel = solutions.drawing_utils._normalized_to_pixel_coordinates(
                pt.x, pt.y, width, height)
            if pixel is not None:
                msg.x = pixel[0]
                msg.y = pixel[1]
                msg.score = pt.visibility
            else:
                msg.x = 0
                msg.y = 0
                msg.score = 0
            keypoints_msgs.append(msg)
        self.person2d_msg.keypoints = keypoints_msgs

        return annotated_image

    def draw_landmarks_on_image(self, rgb_image, detection_result):
        """Draw every detected pose on a copy of *rgb_image*.

        Returns (annotated_image, landmarks), where landmarks belong to the
        LAST pose drawn (empty list when no pose was detected) — this keeps
        the original single-person assumption.
        """
        pose_landmarks_list = detection_result.pose_landmarks
        if not pose_landmarks_list:
            return rgb_image, []

        annotated_image = np.copy(rgb_image)
        for pose_landmarks in pose_landmarks_list:
            # Repack the task-API landmarks into the proto format expected
            # by solutions.drawing_utils.
            pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            pose_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(
                    x=landmark.x, y=landmark.y, z=landmark.z,
                    visibility=landmark.visibility)
                for landmark in pose_landmarks
            ])
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                pose_landmarks_proto,
                solutions.pose.POSE_CONNECTIONS,
                solutions.drawing_styles.get_default_pose_landmarks_style())

        return annotated_image, pose_landmarks_proto.landmark

    def keypoints_to_image(self, keypoints, img, skeleton, kp_thresh, alpha=0.7):
        """Overlay a skeleton on *img* and return the blended result.

        keypoints -- Nx3 array of (x, y, score) rows.
        skeleton  -- object exposing root, children, keypoint2index,
                     keypoint_num (e.g. the bvh_skeleton definitions).
        kp_thresh -- minimum score for a joint/bone to be drawn; None draws all.
        alpha     -- blend factor of the skeleton layer over the original image.
        """
        # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
        cmap = plt.get_cmap('rainbow')
        colors = [cmap(i) for i in np.linspace(0, 1, skeleton.keypoint_num)]
        colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

        mask = img.copy()
        root = skeleton.root
        stack = [root]

        # Depth-first walk of the skeleton tree, drawing joints and bones.
        while stack:
            parent = stack.pop()
            p_idx = skeleton.keypoint2index[parent]
            p_pos = int(keypoints[p_idx, 0]), int(keypoints[p_idx, 1])
            p_score = keypoints[p_idx, 2] if kp_thresh is not None else None

            if kp_thresh is None or p_score > kp_thresh:
                cv2.circle(
                    mask, p_pos, radius=3,
                    color=colors[p_idx], thickness=-1, lineType=cv2.LINE_AA)

            for child in skeleton.children[parent]:
                if child not in skeleton.keypoint2index or skeleton.keypoint2index[child] < 0:
                    continue
                stack.append(child)
                c_idx = skeleton.keypoint2index[child]
                c_pos = int(keypoints[c_idx, 0]), int(keypoints[c_idx, 1])
                # BUGFIX: use the same 'is not None' test as p_score above;
                # the old truthiness test made kp_thresh == 0 leave c_score as
                # None and then raised TypeError on 'None > 0'.
                c_score = keypoints[c_idx, 2] if kp_thresh is not None else None
                if kp_thresh is None or \
                        (p_score > kp_thresh and c_score > kp_thresh):
                    cv2.line(
                        mask, p_pos, c_pos,
                        color=colors[c_idx], thickness=2, lineType=cv2.LINE_AA)

        output_img = cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)

        return output_img


if __name__ == '__main__':
    # Every CLI argument except the last two is an image topic to process
    # (the trailing arguments are presumably ROS remapping args — the node
    # only runs pose estimation on sys.argv[1:-2]).
    topics = sys.argv[1:-2]

    # Spin up the node and announce what will be published.
    rospy.init_node('hpe_skeleton')
    rate = rospy.Rate(5)
    rospy.loginfo("Publishing images with human skeleton")
    rospy.loginfo("Topics: " + str(topics))

    # Keep a reference to every HPE instance so their subscribers stay alive.
    cameras = [HPE(camera_topic) for camera_topic in topics]

    # Callbacks do all the work; just idle until shutdown.
    while not rospy.is_shutdown():
        rate.sleep()
