#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image
from hand_gesture_recognition.msg import HandsDetected, HandsClassified
from std_msgs.msg import Int32
import cv2
import numpy as np
from cv_bridge import CvBridge
import copy
import mediapipe as mp
import os
from larcc_gestures.utils.networks import InceptionV3
import torch
import time
from torchvision import transforms
import yaml
from yaml.loader import SafeLoader
from larcc_gestures.utils.hgr_utils import find_hands, take_decision
from vision_config.vision_definitions import ROOT_DIR, USERNAME


class HandGestureRecognition:
    """Replay a folder of recorded frames through ``find_hands`` and record the result.

    Constructing an instance does the whole job synchronously:

    1. builds the MediaPipe pose detector and the InceptionV3 gesture model,
    2. lists the image files of the dataset folder in frame-number order,
    3. runs ``find_hands`` on every frame, shows the returned keypoint image
       in the "Recording" window and appends it to ``VIDEO_NAME``.

    Per-hand gesture classification (``predict`` + ``take_decision``) and
    bounding-box annotation are not performed during the replay; only the
    keypoint overlay returned by ``find_hands`` is recorded. The model is
    still loaded so that ``predict`` remains usable on the instance.

    Attributes set by the constructor: ``cv_image``, ``header``, ``bridge``,
    ``thresholds``, ``device``, ``model``, ``data_transform``.
    """

    # Size and rate the output video is opened with.
    # NOTE(review): OpenCV's VideoWriter.write expects frames of exactly this
    # size - confirm the recorded frames are 640x480.
    FRAME_WIDTH = 640
    FRAME_HEIGHT = 480
    FPS = 10
    VIDEO_NAME = 'manel_constraints.mp4'

    # Halved and passed to find_hands as x_lim / y_lim.
    ROI_WIDTH = 100
    ROI_HEIGHT = 100

    def __init__(self, thresholds, cm, **kargs) -> None:
        """Load the models and replay the recorded dataset.

        Args:
            thresholds: stored on ``self.thresholds``; not otherwise used by
                the replay.
            cm: confusion matrix; accepted for interface compatibility and
                currently unused.
            **kargs: configuration mapping. The keys read here are
                ``"con_features"``, ``"model_name"`` and ``"viz"``.

        Raises:
            KeyError: if one of the required configuration keys is missing.
            FileNotFoundError: if the dataset folder or the model weights
                file does not exist.
        """
        dataset_path = f"/home/{USERNAME}/rosbags/manel"

        # Echo the configuration that was passed in.
        print(kargs)

        self.cv_image = None
        self.header = None
        self.bridge = CvBridge()
        self.thresholds = thresholds

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Training device: ", self.device)

        self.model = self._load_model(kargs["model_name"], kargs["con_features"])

        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.25, 0.25, 0.25])
        self.data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(299),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # One forward pass on a blank frame before the replay starts
        # (presumably to warm the model up so the first real frame is not slow).
        blank_image = np.zeros((self.FRAME_HEIGHT, self.FRAME_WIDTH, 3), np.uint8)
        self.predict(blank_image, flip=True)

        frame_names = self._sorted_frame_names(os.listdir(dataset_path))
        if not frame_names:
            print(f"No numbered frames found in {dataset_path}")

        self._record_sequence(dataset_path, frame_names, self._build_mp_data(), kargs["viz"])

    @staticmethod
    def _build_mp_data():
        """Build the dictionary of MediaPipe handles and landmark indices given to ``find_hands``."""
        mp_pose = mp.solutions.pose
        return {
            "face_points": (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
            # shoulders and hips
            # NOTE(review): in MediaPipe's pose landmark numbering the hips are
            # 23/24 and 25 is the left knee - confirm (24, 25) is intended.
            "threshold_points": (11, 12, 24, 25),
            "left_hand_points": (16, 18, 20, 22),
            "right_hand_points": (15, 17, 19, 21),
            "mp_drawing": mp.solutions.drawing_utils,
            "mp_drawing_styles": mp.solutions.drawing_styles,
            "mp_pose": mp_pose,
            "pose": mp_pose.Pose(static_image_mode=False,
                                 model_complexity=2,
                                 enable_segmentation=False,
                                 min_detection_confidence=0.7),
        }

    def _load_model(self, model_name, con_features):
        """Create the InceptionV3 network, load its trained weights and put it in eval mode."""
        # The two leading positional arguments are presumably the number of
        # gesture classes and the learning rate - confirm against InceptionV3.
        model = InceptionV3(4, 0.0001, unfreeze_layers=list(np.arange(13, 20)),
                            class_features=2048, device=self.device,
                            con_features=con_features)
        model.name = model_name

        # expanduser still resolves the home directory when $HOME is unset,
        # where os.getenv("HOME") would yield the literal string "None".
        weights_path = f'{os.path.expanduser("~")}/models/{model.name}/{model.name}.pth'
        model.load_state_dict(torch.load(weights_path, map_location=self.device))

        model.eval()
        model.to(self.device)
        return model

    @staticmethod
    def _sorted_frame_names(names):
        """Return ``names`` ordered by the integer formed from all digits in each name.

        Names that contain no digit cannot be ordered and are skipped with a
        message instead of aborting the run. Ties on the number are broken by
        the name itself. An empty input yields an empty list.
        """
        numbered = []
        for name in names:
            # isdecimal() keeps exactly the characters int() can parse.
            digits = "".join(ch for ch in name if ch.isdecimal())
            if not digits:
                print(f"Skipping '{name}': no frame number in its name")
                continue
            numbered.append((int(digits), name))

        return [name for _, name in sorted(numbered)]

    def _record_sequence(self, dataset_path, frame_names, mp_data, viz):
        """Run ``find_hands`` on every frame, display the overlay and write it to the video.

        The loop is paced to ``FPS`` frames per second. Pressing ``q`` in a
        window stops the replay. The video writer is always released and the
        windows are always destroyed, even if a frame raises.
        """
        cv2.namedWindow("Recording", cv2.WINDOW_NORMAL)
        if viz:
            cv2.namedWindow("MediaPipe Image", cv2.WINDOW_NORMAL)

        # Opened only now, after the models loaded successfully, so an early
        # failure does not leave a dangling writer / empty output file.
        writer = cv2.VideoWriter(self.VIDEO_NAME, cv2.VideoWriter_fourcc(*'DIVX'),
                                 self.FPS, (self.FRAME_WIDTH, self.FRAME_HEIGHT))
        if not writer.isOpened():
            print(f"Warning: could not open '{self.VIDEO_NAME}' for writing; nothing will be recorded")

        frame_period = 1 / self.FPS

        try:
            for img_name in frame_names:
                start = time.perf_counter()

                bgr = cv2.imread(f"{dataset_path}/{img_name}")
                if bgr is None:
                    # imread returns None for anything it cannot decode.
                    print(f"Skipping '{img_name}': not a readable image")
                    continue

                img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

                # find_hands returns seven values; only the keypoint image and
                # the pose object (stored back for the next frame) are used.
                _, _, _, _, keypoints_image, mp_data["pose"], _ = find_hands(
                    img, mp_data, x_lim=int(self.ROI_WIDTH / 2), y_lim=int(self.ROI_HEIGHT / 2))

                # The keypoint image is treated as RGB; OpenCV displays and
                # encodes BGR.
                frame = cv2.cvtColor(keypoints_image, cv2.COLOR_RGB2BGR)

                cv2.imshow("Recording", frame)
                writer.write(frame)

                if cv2.waitKey(1) == ord("q"):
                    break

                if viz:
                    # Same channel swap as above, so the converted frame is reused.
                    cv2.imshow("MediaPipe Image", frame)
                    # Blocks until a key is pressed: viz mode steps frame by frame.
                    if cv2.waitKey() == ord('q'):
                        break

                # Sleep once for whatever is left of this frame's time slot.
                remaining = frame_period - (time.perf_counter() - start)
                if remaining > 0:
                    time.sleep(remaining)

                print(f"CLASSIFICATION Running at {round(1 / (time.perf_counter() - start), 2)} FPS")
        finally:
            # Releasing finalises the video file; without it the container
            # may be left incomplete.
            writer.release()
            cv2.destroyAllWindows()

    def predict(self, hand, flip):
        """Classify one hand crop with the gesture model.

        Args:
            hand: image of the hand, in whatever form ``self.data_transform``
                accepts.
            flip: when true the image is mirrored horizontally first.

        Returns:
            ``(outputs, preds)``: the first element returned by the model and
            the index of its maximum along dimension 1.
        """
        if flip:
            hand = cv2.flip(hand, 1)

        # Add the batch dimension and move the tensor to the model's device.
        im_norm = self.data_transform(hand).unsqueeze(0)
        im_norm = im_norm.to(self.device)

        with torch.no_grad():
            outputs, _ = self.model(im_norm)
            _, preds = torch.max(outputs, 1)

        return outputs, preds

if __name__=="__main__":
    # Script entry point: read the model configuration and the decision
    # thresholds from the larcc_gestures package, then run the offline replay
    # (all of the work happens inside the HandGestureRecognition constructor).

    model_name = "InceptionV3"
    config_dir = f'{os.path.expanduser("~")}/catkin_ws/src/larcc/larcc_gestures/config/model'

    # SafeLoader only builds plain Python objects from the YAML files.
    with open(f'{config_dir}/{model_name}.yaml') as f:
        data = yaml.load(f, Loader=SafeLoader)
    print(data)

    with open(f'{config_dir}/thresholds.yaml') as f:
        t = yaml.load(f, Loader=SafeLoader)
    print(t)

    # The model configuration selects which threshold set / confusion matrix to use.
    thresholds = t["thresholds"][data["threshold_choice"]]
    cm = t["confusion_matrix"][data["threshold_choice"]]
    print(thresholds)
    print(cm)

    try:
        hd = HandGestureRecognition(thresholds, cm, **data)

        # rospy.spin() raises ROSInitException when no node was initialised,
        # and this script never calls rospy.init_node(); only spin if some
        # other code did initialise one.
        if rospy.core.is_initialized():
            rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")
    finally:
        cv2.destroyAllWindows()
