"""Lift cached 2D human poses to 3D and express them in the world frame.

For each camera topic the script loads the cached 2D keypoints
(<cache>/<topic>/2d_pose.npy), lifts them to 3D with estimator_3d, transforms
the result from the camera optical frame to the world frame using the static
transforms of the robot's URDF, and renders the sequences as animated GIFs.
"""

import os

import cv2
import numpy as np
from IPython.display import HTML
from tf.transformations import quaternion_from_euler
from urdf_parser_py.urdf import URDF

from atom_core.naming import generateKey
from bvh_skeleton import h36m_skeleton
from pose_estimator_3d import estimator_3d
from scripts.utils import camera, vis
from scripts.utils.transforms import getTransform


class HPE_3D():
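    """3D human pose estimation for a single camera topic, expressed in the world frame."""
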
    def __init__(self, topic):
        self.image = None
        self.topic = topic
        pose2d_file = os.path.join("/home/daniela/catkin_ws/src/hpe/images/cache", self.topic, "2d_pose.npy")
        urdf_file = "xacros/well_optimized.urdf.xacro"
        self.pose2d = np.load(pose2d_file, allow_pickle=True)          # cached 2D keypoints for this topic
        self.transform_dict = self.get_transform_tree_dict(urdf_file)  # static transforms parsed from the URDF
        self.pose3d = None
        self.camera_link = str(topic) + '_rgb_optical_frame'

    def extract_3d_skeleton(self):
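        """Lift the cached 2D keypoints to a 3D pose sequence."""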
        e3d = estimator_3d.Estimator3D(
            config_file='../models/openpose_video_pose_243f/video_pose.yaml',
            checkpoint_file='../models/openpose_video_pose_243f/best_58.58.pth'
        )

        # Read the first cached frame to recover the image resolution the 2D poses refer to.
        image_file = os.path.join("/home/daniela/catkin_ws/src/hpe/images/cache", self.topic, "0.png")
        img = cv2.imread(image_file)
        if img is None:
            raise IOError("Could not read cached frame: " + image_file)
        img_height, img_width = img.shape[0], img.shape[1]

        pose3d = e3d.estimate(self.pose2d, image_width=img_width, image_height=img_height)
        self.pose3d = pose3d
        return pose3d

    def convert_pose_to_world(self):
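        """Transform the estimated 3D pose from the camera optical frame to the world frame."""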
        T = getTransform('world', self.camera_link, self.transform_dict)

        R = T[0:3, 0:3]    # rotation part of the transform
        trans = T[0:3, 3]  # translation part

        pose3d_world = camera.camera2world(pose=self.pose3d, R=R, T=trans)
        pose3d_world[:, :, 2] -= np.min(pose3d_world[:, :, 2])  # rebase the height
        return pose3d_world

    def visualize(self, pose3d_world):
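        """Render this camera's world-frame pose sequence as an animated GIF."""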
        gif_file = '/home/daniela/catkin_ws/src/hpe/images/cache/3d_pose_500_' + self.topic + '.gif'  # output format can be .gif or .mp4
        h36m_skel = h36m_skeleton.H36mSkeleton()
        ani = vis.vis_3d_keypoints_sequence(
            keypoints_sequence=pose3d_world[0:500],
            skeleton=h36m_skel,
            azimuth=0,
            fps=10,
            output_file=gif_file
        )
        HTML(ani.to_jshtml())  # only displayed when run inside a Jupyter notebook


    def get_transform_tree_dict(self, file):
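        """Build a transform dictionary from the URDF joints, keyed by generateKey(parent, child)."""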
        xml_robot = URDF.from_xml_file(file)
        transforms = {}

        for joint in xml_robot.joints:
            child = joint.child
            parent = joint.parent
            xyz = joint.origin.xyz
            rpy = joint.origin.rpy
            key = generateKey(parent, child)

            transforms[key] = {
                'child': child,
                'parent': parent,
                'trans': xyz,
                'quat': list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz')),
            }

        return transforms


def visualize_n_cameras(poses):
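    """Render the world-frame pose sequences from several cameras in one animation."""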
    gif_file = '/home/daniela/catkin_ws/src/hpe/images/cache/3d_pose_500_multi_cam_colors.gif'  # output format can be .gif or .mp4
    h36m_skel = h36m_skeleton.H36mSkeleton()
    ani = vis.vis_3d_keypoints_sequence_multi_cam(
        keypoints_sequences=poses,
        skeleton=h36m_skel,
        azimuth=0,
        fps=10,
        output_file=gif_file
    )
    HTML(ani.to_jshtml())  # only displayed when run inside a Jupyter notebook

def main():
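    """Run the pipeline for each camera topic and render the combined animation."""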
    # cameras = ['camera_2']
    cameras = ['camera_2', 'camera_3', 'camera_4']
    poses = []

    for topic in cameras:
        cam = HPE_3D(topic)
        cam.extract_3d_skeleton()
        pose3d = cam.convert_pose_to_world()
        # cam.visualize(pose3d)  # optional per-camera animation
        poses.append(pose3d[0:500])  # keep the first 500 frames of each sequence

    visualize_n_cameras(poses)

if __name__ == "__main__":
    main()
