#!/usr/bin/env python3
import os
import sys
import threading
from pathlib import Path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
                + '/scripts/')

import rospy
from sensor_msgs.msg import CameraInfo, Image, CompressedImage
from visualization_msgs.msg import Marker
from geometry_msgs.msg import Point, TransformStamped
from tf2_msgs.msg import TFMessage
from cv_bridge import CvBridge

import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from hpe.msg import skeleton, person2D, keypoint2D, person3D, keypoint3D
import functools
import operator
from functools import partial

from utils.links import joints_mediapipe, links_mediapipe
from utils.transforms import getTransform, getChain, getAggregateTransform
from urdf_parser_py.urdf import URDF
from atom_core.naming import generateKey
from tf.transformations import quaternion_from_matrix, quaternion_from_euler
from optimization.objective_function_ros import objectiveFunction

if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager


class HPE:
    """Multi-camera 3D human pose estimation node.

    Subscribes to one ``person2D`` topic per camera and groups messages whose
    header stamps fall within ``sync_time_window`` into per-camera "frames"
    (``self.cameras[camera]['frames'][str(frame_idx)]``). ``hpe_3d`` then runs an
    OptimizationUtils least-squares optimization that estimates the 3D joint
    positions of the current frame (``self.poses3d``).
    """

    # Pairs of MediaPipe keypoint indices drawn as line segments by
    # draw_3D_skeleton (same pairs, in the same order, as the original code).
    SKELETON_SEGMENTS = (
        (12, 14), (14, 16), (11, 13), (0, 0), (12, 11), (24, 23), (12, 24), (11, 23),
        (24, 26), (26, 28), (23, 25), (25, 27), (28, 30), (28, 32), (27, 29), (27, 31),
    )

    def __init__(self, topics):
        """Set up publishers, camera models and (last) the 2D keypoint subscribers.

        Args:
            topics: person2D topic names. The part before the first '/' is used
                as the camera name (e.g. ``camera_1/...`` -> ``camera_1``).

        Blocks until a CameraInfo message has been received for every camera
        (see get_camera_intrinsics).
        """
        # True -> skip link-length (sll) / frame-to-frame (sff) residuals.
        self.sll = True
        self.sff = True

        # Publishers - 3D skeleton keypoint_list and markers to draw skeleton
        self.publish_skeleton = rospy.Publisher("/3D_skeleton", person3D, queue_size=2)
        self.person3d_msg = person3D()
        self.marker_pub = rospy.Publisher("/visualization_marker", Marker, queue_size=10)
        self.person3d_msg.header.frame_id = "world"
        self.person3d_msg.header.stamp = rospy.Time.now()
        # Kept for backward compatibility; hpe_3d now uses current_frame.
        self.frame_idx = 0
        # Index assigned to the next synchronized set of messages.
        self.global_frame_idx = 0
        self.selected_camera_topic = ''

        self.lock_frame = False
        self.current_frame = 0

        self.poses3d = {}

        self.topics = topics
        self.valid_kpt = False
        self.keypoint_list = {topic: np.zeros((33, 3)) for topic in topics}
        self.received = {topic: False for topic in topics}
        self.message_buffers = {topic: [] for topic in topics}
        self.sync_time_window = rospy.Duration(0.2)  # 200 ms time window for synchronization
        self.min_synced_topics = 2  # topics needed to form a synchronized frame
        # rospy runs subscriber callbacks on their own threads; this lock guards
        # message_buffers, global_frame_idx and the per-camera 'frames' dicts.
        self.sync_lock = threading.Lock()

        self.transform_tree_dict = self.get_transform_tree_dict(
            os.path.dirname(os.path.abspath(__file__)) + '/../xacro/larcc.urdf')

        self.cameras = {}
        for topic in self.topics:
            camera_key = str(topic).split('/')[0]
            self.cameras[camera_key] = {'frames': {}}
            if self.selected_camera_topic == '':
                self.selected_camera_topic = camera_key
        # Keyed by camera name, matching how process_synchronized_msgs writes it.
        self.camera_detected = {camera_key: False for camera_key in self.cameras}

        self.get_camera_extrinsics()
        self.get_camera_intrinsics()

        self.args = {
            'has_ground_truth': False,
            'dataset_name': "mpi",
            '2d_detector': "mediapipe",
            'skip_frame_to_frame_residuals': self.sff,
            'skip_link_length_residuals': self.sll,
            'debug': False
        }

        # Subscribe last: callbacks may fire immediately on another thread and
        # need every attribute above to exist.
        self.subscribers = {topic: rospy.Subscriber(topic, person2D, self.callback, topic)
                            for topic in topics}

    def callback(self, msg, topic):
        """Buffer an incoming person2D message and try to synchronize."""
        with self.sync_lock:
            self.message_buffers[topic].append(msg)
        self.sync_messages()

    def sync_messages(self):
        """Emit every synchronized set of messages currently available."""
        with self.sync_lock:
            while self._sync_once():
                pass

    def _sync_once(self):
        """Try one synchronization step; return True if any buffer changed.

        The oldest buffered message (by the head of each buffer) is the
        reference. Every topic contributes its first message within
        sync_time_window of it. Messages are removed only when at least
        min_synced_topics topics matched. Otherwise the reference is discarded
        once any buffered message is newer than reference + window, so buffers
        stay bounded.
        """
        buffers = self.message_buffers
        earliest = {topic: buffer[0].header.stamp for topic, buffer in buffers.items() if buffer}
        if not earliest:
            return False

        reference_topic = min(earliest, key=earliest.get)
        reference_time = earliest[reference_topic]

        matched = {}
        for topic, buffer in buffers.items():
            for msg in buffer:
                if abs(msg.header.stamp - reference_time) <= self.sync_time_window:
                    matched[topic] = msg
                    break

        if len(matched) >= self.min_synced_topics:
            for topic, msg in matched.items():
                buffers[topic].remove(msg)
            print(str(len(matched)) + " synchronized messages")
            self.process_synchronized_msgs(matched)
            self.global_frame_idx += 1
            return True

        newest = max(msg.header.stamp for buffer in buffers.values() for msg in buffer)
        if newest - reference_time > self.sync_time_window:
            buffers[reference_topic].pop(0)
            return True
        return False

    def process_synchronized_msgs(self, msgs):
        """Store one synchronized set of 2D detections as frame global_frame_idx.

        Args:
            msgs: mapping topic -> person2D message. Messages with no keypoints
                are ignored.
        """
        frame_key = str(self.global_frame_idx)
        for topic, msg in msgs.items():
            if not msg.keypoints:
                continue

            keypoints = self.keypoint_list[topic]
            joints = {}  # fresh per camera: each camera's frame owns its joints
            for idx, kpt in enumerate(msg.keypoints):
                if idx < len(keypoints):
                    keypoints[idx] = (kpt.x, kpt.y, kpt.score)
                joints[idx] = {'x': kpt.x, 'y': kpt.y, 'confidence': kpt.score,
                               # a coordinate of exactly 0 marks a missing detection
                               'valid': kpt.x != 0 and kpt.y != 0,
                               'x_proj': 0.0, 'y_proj': 0.0}

            camera_key = str(topic).split('/')[0]
            self.cameras[camera_key]['frames'][frame_key] = {'joints': joints, 'valid_frame': True}

            self.valid_kpt = True
            if not self.lock_frame:
                # NOTE(review): points at the previously stored frame; confirm intent.
                self.current_frame = self.global_frame_idx - 1
            self.camera_detected[camera_key] = True

    def get_transform_tree_dict(self, file):
        """Parse a URDF file into {generateKey(parent, child): transform} entries.

        Each entry holds 'parent', 'child', 'trans' (joint origin xyz) and 'quat'
        (quaternion from the origin rpy, 'sxyz' axes). A joint without <origin>
        gets an identity transform.
        """
        xml_robot = URDF.from_xml_file(file)
        transforms = {}

        for joint in xml_robot.joints:
            origin = joint.origin
            xyz = origin.xyz if origin is not None and origin.xyz is not None else [0.0, 0.0, 0.0]
            rpy = origin.rpy if origin is not None and origin.rpy is not None else [0.0, 0.0, 0.0]
            transforms[generateKey(joint.parent, joint.child)] = {
                'child': joint.child,
                'parent': joint.parent,
                'trans': xyz,
                'quat': list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz')),
            }

        return transforms

    def get_camera_extrinsics(self):
        """Store the 'world' -> '<camera>_link' transform of every camera as 'extrinsics'."""
        for camera_key in self.cameras:
            self.cameras[camera_key]['extrinsics'] = getTransform(
                'world', camera_key + '_link', self.transform_tree_dict)

    def get_camera_intrinsics(self):
        """Block until /<camera>/rgb/camera_info arrives, then store K, D and image size."""
        for camera_key, camera in self.cameras.items():
            msg = rospy.wait_for_message('/' + camera_key + '/rgb/camera_info', CameraInfo, timeout=None)
            camera['intrinsics'] = np.reshape(msg.K, (3, 3))
            # The length of D depends on msg.distortion_model (5 for plumb_bob).
            camera['distortion'] = np.reshape(msg.D, (-1, 1))
            camera['height'] = msg.height
            camera['width'] = msg.width

    @staticmethod
    def _joints_to_use():
        """Return every joint that appears as a parent or child in links_mediapipe."""
        return sorted({joint for link in links_mediapipe.values()
                       for joint in (link['parent'], link['child'])})

    def _snapshot_cameras(self):
        """Copy the camera/frame dicts under the lock.

        Callbacks keep adding frames while the optimizer runs. Joint dicts are
        shared, not copied, so values written into them stay visible.
        """
        with self.sync_lock:
            return {camera_key: dict(camera, frames=dict(camera['frames']))
                    for camera_key, camera in self.cameras.items()}

    @staticmethod
    def _count_valid_cameras(cameras, frame_idx):
        """Count cameras whose frame `frame_idx` exists and is flagged valid."""
        frame_key = str(frame_idx)
        n_valid = 0
        for camera_key, camera in cameras.items():
            frame = camera['frames'].get(frame_key)
            if frame is None:
                continue
            print("Frame " + frame_key + " has valid camera " + str(camera_key) + ".")
            if frame['valid_frame']:
                n_valid += 1
        return n_valid

    @staticmethod
    def _get_joint_xyz(data, frame_key, joint_key):
        """OptimizationUtils getter: [X, Y, Z] of a joint in poses3d."""
        d = data[frame_key][joint_key]
        return [d['X'], d['Y'], d['Z']]

    @staticmethod
    def _set_joint_xyz(data, values, frame_key, joint_key):
        """OptimizationUtils setter: write values[0:3] into a joint in poses3d."""
        d = data[frame_key][joint_key]
        d['X'], d['Y'], d['Z'] = values[0], values[1], values[2]

    def hpe_3d(self):
        """Estimate the 3D pose of current_frame when more than 2 cameras see it."""
        self.lock_frame = True
        try:
            frame_idx = self.current_frame  # local copy; callbacks may still update it
            joints_to_use = self._joints_to_use()
            self.poses3d[frame_idx] = {joint_key: {'X': 0.0, 'Y': 0.0, 'Z': 0.0}
                                       for joint_key in joints_to_use}

            cameras = self._snapshot_cameras()
            n_valid_frames = self._count_valid_cameras(cameras, frame_idx)
            print("Frame " + str(frame_idx) + " has " + str(n_valid_frames) + " valid cameras.")

            if n_valid_frames > 2:
                print("OPTIMIZING WITH " + str(n_valid_frames) + " VALID CAMERAS FOR CURRENT FRAME")
                self._optimize_frame(cameras, frame_idx, joints_to_use)
        finally:
            self.lock_frame = False

    def _optimize_frame(self, cameras, frame_idx, joints_to_use):
        """Build and run the OptimizationUtils problem for poses3d[frame_idx]."""
        opt = Optimizer()
        opt.addDataModel('args', self.args)
        opt.addDataModel('cameras', cameras)
        opt.addDataModel('poses3d', self.poses3d)
        opt.addDataModel('joints_to_use', joints_to_use)
        opt.addDataModel('errors', {})
        opt.addDataModel('frames_to_use', [frame_idx])

        # One X/Y/Z parameter group per joint of the frame being optimized.
        for joint_key in joints_to_use:
            group_name = 'frame_' + str(frame_idx) + '_joint_' + str(joint_key)
            opt.pushParamVector(group_name, data_key='poses3d',
                                getter=partial(self._get_joint_xyz, frame_key=frame_idx, joint_key=joint_key),
                                setter=partial(self._set_joint_xyz, frame_key=frame_idx, joint_key=joint_key),
                                suffix=['_X', '_Y', '_Z'])

        opt.setObjectiveFunction(objectiveFunction)

        self._push_projection_residuals(opt, cameras, joints_to_use)
        if not self.sff:
            self._push_frame_to_frame_residuals(opt)
        if not self.sll:
            self._push_link_length_residuals(opt)

        opt.printResiduals()
        print(len(opt.residuals))

        opt.computeSparseMatrix()
        opt.printSparseMatrix()
        opt.startOptimization(optimization_options={'x_scale': 'jac', 'ftol': 1e-7,
                                                    'xtol': 1e-7, 'gtol': 1e-7,
                                                    'diff_step': None})

    @staticmethod
    def _push_projection_residuals(opt, cameras, joints_to_use):
        """Push one reprojection residual per camera, stored frame and valid joint."""
        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                for joint_key in joints_to_use:
                    joint = frame['joints'].get(joints_mediapipe[joint_key]['idx'])
                    if joint is None or not joint['valid']:  # skip invalid joints
                        print("camera " + str(camera_key) + ": Joint " + str(
                            joint_key) + ' is ' + 'invalid. (residuals)')
                        continue

                    # Trailing '_' so a joint name cannot match a longer joint name.
                    parameter_pattern = 'frame_' + str(frame_key) + '_joint_' + str(joint_key) + '_'
                    residual_key = 'projection_sensor_' + camera_key + '_frame_' + str(
                        frame_key) + '_joint_' + str(joint_key)
                    params = opt.getParamsContainingPattern(pattern=parameter_pattern)
                    opt.pushResidual(name=residual_key, params=params)

    def _push_frame_to_frame_residuals(self, opt):
        """Push a residual per joint between consecutive frames in poses3d."""
        frame_keys = list(self.poses3d.keys())
        for initial_frame_key, final_frame_key in zip(frame_keys[:-1], frame_keys[1:]):
            if int(final_frame_key) - int(initial_frame_key) != 1:  # frames are not consecutive
                continue
            initial, final = str(initial_frame_key), str(final_frame_key)
            for joint_key in self.poses3d[initial_frame_key]:
                residual_key = 'consecutive_frame_' + initial + '_frame_' + final + '_joint_' + joint_key
                params = opt.getParamsContainingPattern(pattern='frame_' + initial + '_joint_' + joint_key + '_')
                params.extend(opt.getParamsContainingPattern(pattern='frame_' + final + '_joint_' + joint_key + '_'))
                opt.pushResidual(name=residual_key, params=params)

    def _push_link_length_residuals(self, opt):
        """Push a residual per link and frame in poses3d (link-length constraint)."""
        for link in links_mediapipe.values():
            print('Link ' + link['parent'] + '-' + link['child'])
            for frame_key in self.poses3d:
                frame = str(frame_key)
                params = opt.getParamsContainingPattern(
                    pattern='frame_' + frame + '_joint_' + link['parent'] + '_')
                params.extend(opt.getParamsContainingPattern(
                    pattern='frame_' + frame + '_joint_' + link['child'] + '_'))
                residual_key = 'length_frame_' + frame + '_joint_' + link['parent'] + '_joint_' + link['child']
                opt.pushResidual(name=residual_key, params=params)

    def draw_3D_skeleton(self, keypoints3d):
        """Publish the skeleton once as a LINE_LIST marker in the 'world' frame.

        Args:
            keypoints3d: indexable so that keypoints3d[i][0..2] gives the x, y, z
                of MediaPipe keypoint i (indices up to 32 are read).
        """
        marker = Marker()
        marker.type = marker.LINE_LIST
        marker.action = marker.ADD
        marker.header.frame_id = 'world'
        marker.scale.x = 1.0  # line width

        marker.color.a = 1.0
        marker.color.r = 0.0
        marker.color.g = 0.0
        marker.color.b = 0.0

        marker.pose.orientation.x = 0.0
        marker.pose.orientation.y = 0.0
        marker.pose.orientation.z = 0.0
        marker.pose.orientation.w = 1.0
        marker.pose.position.x = 0.0
        marker.pose.position.y = 0.0
        marker.pose.position.z = 0.0

        used = {idx for segment in self.SKELETON_SEGMENTS for idx in segment}
        points = {idx: Point(x=keypoints3d[idx][0], y=keypoints3d[idx][1], z=keypoints3d[idx][2])
                  for idx in used}
        marker.points = [points[idx] for segment in self.SKELETON_SEGMENTS for idx in segment]

        # Publish once; the previous busy `while not rospy.is_shutdown()` loop
        # never returned. Call again to republish.
        self.marker_pub.publish(marker)


if __name__ == '__main__':
    # Topics to do HPE on. rospy.myargv strips ROS remapping arguments (such as
    # the __name:= / __log:= pair roslaunch appends), so this works under both
    # roslaunch and rosrun, unlike a fixed sys.argv[1:-2] slice.
    topics = rospy.myargv(argv=sys.argv)[1:]

    # start python wrapper
    rospy.init_node('hpe_skeleton')
    if not topics:
        rospy.logerr("No person2D topics given; pass one or more topics as arguments.")
        sys.exit(1)

    rate = rospy.Rate(5)
    rospy.loginfo("Publishing 3D human skeleton")
    rospy.loginfo("Topics: " + str(topics))
    node = HPE(topics)
    while not rospy.is_shutdown():
        node.hpe_3d()
        rate.sleep()
