#!/usr/bin/env python3
import rospy
from hand_gesture_recognition.msg import HandsDetected
from sensor_msgs.msg import Image
from std_msgs.msg import Int32
import cv2
import numpy as np
from cv_bridge import CvBridge
import copy
import time
import mediapipe as mp
from yaml.loader import SafeLoader
from vision_config.vision_definitions import ROOT_DIR
from larcc_gestures.utils.hgr_utils import find_hands



class HandDetectionNode:
    """ROS node that detects hands in a camera image stream.

    Subscribes to an RGB image topic, runs MediaPipe pose estimation
    (via ``find_hands``) to locate the left and right hands, crops a
    region of interest around each, and publishes the result on
    ``/hgr/hands`` as a ``HandsDetected`` message, rate-limited to the
    ``/hgr/FPS/detection`` parameter.

    NOTE(review): ``__init__`` deliberately blocks in a processing loop
    until ROS shutdown (original design kept for compatibility).
    """

    def __init__(self):
        # --- configuration --------------------------------------------------
        image_topic = rospy.get_param("/hgr/image_topic", default="/camera/color/image_raw")
        roi_height = rospy.get_param("/hgr/height", default=100)  # ROI size in pixels
        roi_width = rospy.get_param("/hgr/width", default=100)
        fps = rospy.get_param("/hgr/FPS/detection")  # target publish rate (Hz)

        # Single bridge instance (was created twice in the original code).
        self.bridge = CvBridge()

        # MediaPipe pose model plus the landmark indices consumed by find_hands.
        mp_data = {}
        mp_data["face_points"] = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
        mp_data["threshold_points"] = (11, 12, 24, 25)  # shoulders and hips
        mp_data["left_hand_points"] = (16, 18, 20, 22)
        mp_data["right_hand_points"] = (15, 17, 19, 21)
        mp_data["mp_drawing"] = mp.solutions.drawing_utils
        mp_data["mp_drawing_styles"] = mp.solutions.drawing_styles
        mp_data["mp_pose"] = mp.solutions.pose
        mp_data["pose"] = mp_data["mp_pose"].Pose(static_image_mode=False,
                                                  model_complexity=2,
                                                  enable_segmentation=False,
                                                  min_detection_confidence=0.7)

        rospy.Subscriber(image_topic, Image, self.image_callback)
        pub_hands = rospy.Publisher("/hgr/hands", HandsDetected, queue_size=10)

        # Latest frame delivered by image_callback (None until first message).
        self.cv_image = None
        self.image_header = None
        self.msg = None

        # Wait for the first frame. Sleep between checks so we do not spin a
        # CPU core, and bail out if ROS shuts down while we are still waiting
        # (the original busy-wait hung forever on shutdown).
        print("Waiting!!")
        while self.cv_image is None:
            if rospy.is_shutdown():
                return
            time.sleep(0.01)

        try:
            while not rospy.is_shutdown():
                st = time.time()
                # Snapshot frame and header together so they stay consistent
                # even if image_callback fires mid-iteration (the original
                # deep-copied self.cv_image a second time for find_hands,
                # which could pair a new frame with an old header).
                image = copy.deepcopy(self.cv_image)
                header = copy.deepcopy(self.image_header)

                left_bounding, right_bounding, hand_right, hand_left, keypoints_image, mp_data["pose"] = find_hands(
                    image, mp_data, x_lim=int(roi_width / 2), y_lim=int(roi_height / 2))

                hands = HandsDetected()
                hands.header = header
                hands.left_bounding_box = [Int32(i) for i in left_bounding]
                hands.right_bounding_box = [Int32(i) for i in right_bounding]

                # Crops are published only when the corresponding hand was found.
                if hand_left is not None:
                    hands.hand_left = self.bridge.cv2_to_imgmsg(hand_left, "rgb8")
                if hand_right is not None:
                    hands.hand_right = self.bridge.cv2_to_imgmsg(hand_right, "rgb8")

                # Rate limiting: sleep once for the remainder of the frame
                # period instead of busy-polling the clock in a tight loop.
                remaining = 1 / fps - (time.time() - st)
                if remaining > 0:
                    time.sleep(remaining)

                pub_hands.publish(hands)
                # print(f"DETECTION Running at {round(1 / (time.time() - st), 2)} FPS")

        except KeyboardInterrupt:
            print("Shutting down")
            cv2.destroyAllWindows()

    def image_callback(self, msg):
        """Cache the newest camera frame (decoded as RGB) and its header."""
        self.cv_image = self.bridge.imgmsg_to_cv2(msg, "rgb8")
        self.image_header = msg.header
        self.msg = msg
        

def _main():
    """Start the hand-detection node and block until ROS shutdown."""
    rospy.init_node("hand_detection", anonymous=True)

    hd = HandDetectionNode()
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")

    cv2.destroyAllWindows()


if __name__ == '__main__':
    _main()
