#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image
from hand_gesture_recognition.msg import HandsDetected, HandsClassified
from std_msgs.msg import Int32
import cv2
import numpy as np
from cv_bridge import CvBridge
import copy
import mediapipe as mp
import os
from larcc_gestures.utils.networks import InceptionV3
import torch
import time
from torchvision import transforms
import yaml
from yaml.loader import SafeLoader
from larcc_gestures.utils.hgr_utils import find_hands, take_decision
from vision_config.vision_definitions import ROOT_DIR


class HandGestureRecognition:
    """ROS node that detects both hands with MediaPipe Pose and classifies their gestures.

    All work happens in the constructor: it reads ``/hgr/*`` parameters, subscribes
    to the camera topic, loads the InceptionV3 classifier from
    ``~/models/<model_name>/<model_name>.pth`` and then runs the
    detect -> classify -> publish loop until ROS shuts down (or 'q' is pressed in
    the visualization window). Bounding boxes are published on ``/hgr/hands`` and
    labels/confidences on ``/hgr/classification``.
    """

    def __init__(self, thresholds, cm, **kargs) -> None:
        """Set up the node and run the recognition loop until shutdown.

        Args:
            thresholds: decision thresholds, forwarded to ``take_decision``.
            cm: confusion matrix for the chosen threshold set, forwarded to
                ``take_decision``.
            **kargs: model configuration. Keys read here: ``n_frames`` (length of
                each per-hand decision buffer), ``con_features`` and ``model_name``
                (network construction / weight file), ``min_coef`` (forwarded to
                ``take_decision``). ``viz`` is overwritten from ``/hgr/viz``.
        """
        print(kargs)

        # Get initial data from rosparams
        fps = rospy.get_param("/hgr/FPS/classification", default=30)
        image_topic = rospy.get_param("/hgr/image_topic", default="/camera/color/image_raw")
        roi_height = rospy.get_param("/hgr/height", default=100)
        roi_width = rospy.get_param("/hgr/width", default=100)
        # Visualization used to be forced on unconditionally; it stays on by
        # default but can now be disabled through the /hgr/viz parameter.
        kargs["viz"] = rospy.get_param("/hgr/viz", default=True)

        print(f"image_topic: {image_topic}")
        mp_data = self._build_mediapipe_data()

        # Initialize the callback's state BEFORE subscribing: image_callback needs
        # self.bridge, and a frame delivered early must not be reset to None.
        self.cv_image = None
        self.header = None
        self.bridge = CvBridge()

        rospy.Subscriber(image_topic, Image, self.image_callback)
        pub_classification = rospy.Publisher("/hgr/classification", HandsClassified, queue_size=10)
        pub_hands = rospy.Publisher("/hgr/hands", HandsDetected, queue_size=10)

        # Initialize variables for classification
        gestures = ["A", "F", "L", "Y", "NONE"]
        none_index = gestures.index("NONE")
        self.thresholds = thresholds

        # Each decision buffer starts filled with n_frames "NONE" predictions.
        buffer_left = [none_index] * kargs["n_frames"]
        buffer_right = [none_index] * kargs["n_frames"]

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Training device: ", self.device)

        self._load_model(kargs["model_name"], kargs["con_features"])

        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.25, 0.25, 0.25])

        self.data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        print("Waiting!!")
        if not self._wait_for_image():
            return  # ROS shut down before the first frame arrived

        if kargs["viz"]:
            cv2.namedWindow("MediaPipe Image", cv2.WINDOW_NORMAL)

        # Classify the first full frame once and discard the result
        # (presumably a warm-up so the first real prediction is not slowed down).
        self.predict(self.cv_image, flip=False)

        period = 1 / fps
        x_lim = int(roi_width / 2)
        y_lim = int(roi_height / 2)
        min_coef = kargs["min_coef"]

        while not rospy.is_shutdown():
            if not self._wait_for_image():
                break  # shutdown requested while waiting for a frame

            st = time.time()

            # Read the frame once: the subscriber thread may replace
            # self.cv_image at any moment.
            image, header = self.cv_image, self.header

            left_bounding, right_bounding, hand_right, hand_left, keypoints_image, mp_data["pose"], _hand_validity = find_hands(
                copy.deepcopy(image), mp_data, x_lim=x_lim, y_lim=y_lim)

            hands = HandsDetected()
            hands.header = header
            hands.left_bounding_box = [Int32(i) for i in left_bounding]
            hands.right_bounding_box = [Int32(i) for i in right_bounding]

            # The left crop is mirrored before classification (flip=True),
            # presumably so it matches the orientation the model was trained on.
            pred_left, confid_left, buffer_left = self._classify_hand(
                hand_left, True, buffer_left, cm, min_coef, none_index)
            pred_right, confid_right, buffer_right = self._classify_hand(
                hand_right, False, buffer_right, cm, min_coef, none_index)

            msg_classification = HandsClassified()
            msg_classification.header = header
            msg_classification.hand_right = gestures[pred_right]
            msg_classification.hand_left = gestures[pred_left]
            msg_classification.confid_right = confid_right
            msg_classification.confid_left = confid_left

            if kargs["viz"]:
                # Frames arrive as rgb8 while OpenCV windows expect BGR.
                cv2.imshow("MediaPipe Image", cv2.cvtColor(keypoints_image, cv2.COLOR_BGR2RGB))
                if cv2.waitKey(1) == ord('q'):
                    break

            # Throttle to at most `fps` iterations per second with one sleep
            # instead of a tight polling loop.
            remaining = period - (time.time() - st)
            if remaining > 0:
                time.sleep(remaining)

            # Discard any frame that arrived meanwhile so the next iteration
            # waits for a fresh one.
            self.cv_image = None
            pub_hands.publish(hands)
            pub_classification.publish(msg_classification)
            print(f"Class: {gestures[pred_left]}")
            print(f"CLASSIFICATION Running at {round(1 / (time.time() - st), 2)} FPS")

        cv2.destroyAllWindows()

    @staticmethod
    def _build_mediapipe_data():
        """Return the landmark index groups and MediaPipe objects passed to find_hands."""
        mp_pose = mp.solutions.pose
        return {
            "face_points": (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
            # shoulders and hips
            # NOTE(review): in MediaPipe's pose landmark map 23/24 are the hips
            # and 25 is the left knee. Left unchanged because find_hands' use of
            # these indices is not visible here — confirm whether
            # (11, 12, 23, 24) was intended.
            "threshold_points": (11, 12, 24, 25),
            # NOTE(review): MediaPipe labels 16/18/20/22 as the subject's RIGHT
            # wrist/pinky/index/thumb; presumably "left" means image-left here.
            "left_hand_points": (16, 18, 20, 22),
            "right_hand_points": (15, 17, 19, 21),
            "mp_drawing": mp.solutions.drawing_utils,
            "mp_drawing_styles": mp.solutions.drawing_styles,
            "mp_pose": mp_pose,
            "pose": mp_pose.Pose(static_image_mode=False,
                                 model_complexity=2,
                                 enable_segmentation=False,
                                 min_detection_confidence=0.7),
        }

    def _load_model(self, model_name, con_features):
        """Build the classifier, load ~/models/<name>/<name>.pth and move it to self.device."""
        self.model = InceptionV3(4, 0.0001, unfreeze_layers=list(np.arange(13, 20)), class_features=2048,
                                 device=self.device, con_features=con_features)
        self.model.name = model_name

        # expanduser("~") instead of $HOME so an unset HOME does not yield "None/models/...".
        weights_path = os.path.join(os.path.expanduser("~"), "models", model_name, f"{model_name}.pth")
        trained_weights = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(trained_weights)

        self.model.eval()
        self.model.to(self.device)

    def _wait_for_image(self, poll_period=0.001):
        """Wait until image_callback has stored a frame.

        Returns:
            True once ``self.cv_image`` is set, False if ROS shut down first.
        """
        while self.cv_image is None:
            if rospy.is_shutdown():
                return False
            time.sleep(poll_period)  # yield the CPU instead of spinning
        return True

    def _classify_hand(self, hand, flip, buffer, cm, min_coef, none_index=4):
        """Classify one hand crop and update its decision buffer.

        Args:
            hand: cropped hand image from find_hands, or None if no hand was found.
            flip: mirror the crop horizontally before classification.
            buffer: this hand's decision buffer, forwarded to take_decision.
            cm: confusion matrix forwarded to take_decision.
            min_coef: forwarded to take_decision.
            none_index: class index reported when ``hand`` is None.

        Returns:
            (prediction, confidence, buffer). When ``hand`` is None this is
            ``(none_index, 1.0, buffer)`` with the buffer untouched.
        """
        if hand is None:
            return none_index, 1.0, buffer

        # NOTE(review): frames are converted to "rgb8" in image_callback, so this
        # BGR2RGB call swaps channels of an RGB crop — confirm it matches the
        # preprocessing used when the model was trained.
        frame = cv2.cvtColor(hand, cv2.COLOR_BGR2RGB)
        outputs, preds = self.predict(frame, flip=flip)
        return take_decision(outputs, preds, self.thresholds, buffer, cm, min_coef=min_coef)

    def image_callback(self, msg):
        """Store the latest camera frame (as an ``rgb8`` array) and its header.

        Registered with rospy.Subscriber; the main loop polls ``self.cv_image``.
        """
        self.cv_image = self.bridge.imgmsg_to_cv2(msg, "rgb8")
        self.header = msg.header

    def predict(self, hand, flip):
        """Run the classifier on a single image.

        Args:
            hand: image accepted by ``self.data_transform``.
            flip: if True, mirror the image horizontally (``cv2.flip(..., 1)``) first.

        Returns:
            (outputs, preds): the model's first output for a batch of one and the
            arg-max index along dim 1.
        """
        if flip:
            hand = cv2.flip(hand, 1)

        im_norm = self.data_transform(hand).unsqueeze(0)  # add batch dimension
        im_norm = im_norm.to(self.device)

        with torch.no_grad():
            outputs, _ = self.model(im_norm)
            _, preds = torch.max(outputs, 1)

        return outputs, preds

if __name__ == "__main__":

    rospy.init_node("hand_gesture_recognition", anonymous=True)

    model_name = rospy.get_param("/hgr/model_name", default="InceptionV3")

    # Directory containing <model_name>.yaml and thresholds.yaml. Defaults to the
    # previously hard-coded catkin workspace location; override with /hgr/config_dir.
    config_dir = rospy.get_param(
        "/hgr/config_dir",
        default=os.path.join(os.path.expanduser("~"), "catkin_ws/src/larcc/larcc_gestures/config/model"),
    )

    # Model configuration; its keys are forwarded to HandGestureRecognition as **kargs.
    with open(os.path.join(config_dir, f"{model_name}.yaml"), encoding="utf-8") as f:
        data = yaml.load(f, Loader=SafeLoader)
        print(data)

    # Threshold sets and confusion matrices, both indexed by data["threshold_choice"].
    with open(os.path.join(config_dir, "thresholds.yaml"), encoding="utf-8") as f:
        t = yaml.load(f, Loader=SafeLoader)
        print(t)

    thresholds = t["thresholds"][data["threshold_choice"]]
    cm = t["confusion_matrix"][data["threshold_choice"]]
    print(thresholds)
    print(cm)

    # The constructor runs the recognition loop itself and returns only on
    # shutdown or when 'q' is pressed in the visualization window.
    hd = HandGestureRecognition(thresholds, cm, **data)
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")

    cv2.destroyAllWindows()
