# from . import camera
from pose_estimator_2d import openpose_estimator
from pose_estimator_3d import estimator_3d
from scripts.utils import camera, vis
# from utils import smooth, vis, camera
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import cv2
import numpy as np
import os
from pathlib import Path
import importlib
import matplotlib.pyplot as plt
from scripts.utils.transforms import getTransform, getChain, getAggregateTransform
from urdf_parser_py.urdf import URDF
from atom_core.naming import generateKey
from tf.transformations import quaternion_from_matrix, quaternion_from_euler
from IPython.display import HTML
import yaml
from yaml.loader import SafeLoader
from image_geometry import PinholeCameraModel
from PIL import Image


class HPE_3D():
    """Lift cached OpenPose 2D skeletons to 3D using per-pixel depth images.

    For one camera topic this class loads the cached 2D BODY_25 keypoints,
    reads the matching depth frame for each image index, back-projects every
    keypoint through the pinhole camera model, and visualises the result.

    NOTE(review): all data paths are hard-coded to one user's catkin
    workspace; consider making the root directory a constructor parameter.
    """

    # Limb connectivity for the OpenPose BODY_25 model, shared by all plots
    # (previously duplicated in two methods).
    CONNECTIONS = [[1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7],
                   [1, 8], [8, 9], [9, 10], [10, 11], [11, 24], [11, 22],
                   [22, 23], [8, 12], [12, 13], [13, 14],
                   [14, 21], [14, 19], [19, 20]]
    # Head/face keypoints excluded from the joint scatter plots.
    SKIP_JOINTS = (0, 15, 16, 17, 18)

    def __init__(self, topic):
        """Load cached 2D poses, camera intrinsics and the transform tree.

        :param topic: camera topic name, e.g. ``'camera_1'``; used to build
            every file path and the optical-frame link name.
        """
        self.image = None
        self.topic = topic
        folder = "/home/daniela/catkin_ws/src/hpe/images/cache/with_depth/" + self.topic + '/2d_pose.npy'
        file = "xacros/well_optimized.urdf.xacro"
        # Cached OpenPose output: one (25, 2) keypoint array per frame.
        self.pose2d = np.load(folder, allow_pickle=True)
        self.transform_dict = self.get_transform_tree_dict(file)
        self.pose3d = None
        self.camera_link = str(topic) + '_rgb_optical_frame'

        # Intrinsics come from the per-camera calibration YAML.
        yaml_folder = "/home/daniela/catkin_ws/src/hpe/images/" + topic + '.yaml'
        with open(yaml_folder) as f:
            data = yaml.load(f, Loader=SafeLoader)
        self.camera_matrix = np.reshape(data['camera_matrix']['data'], (3, 3))

    def get_depth(self, x_pix, y_pix, idx):
        """Return the raw depth value under pixel (x_pix, y_pix) of frame idx.

        Coordinates are clamped into the image bounds because OpenPose can
        report keypoints slightly outside the depth frame (the original code
        clamped only the upper bound; negative values would wrap-index).

        NOTE(review): the value is returned in the depth image's native
        units — confirm (likely millimetres) against the sensor driver.
        """
        path = "/home/daniela/catkin_ws/src/hpe/images/with_depth/" + self.topic + '_depth/' + str(idx) + '.png'
        # IMREAD_UNCHANGED keeps the single-channel depth encoding intact.
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        print("Image index: " + str(idx))
        print(img.shape[0], img.shape[1])
        print(y_pix, x_pix)
        # Clamp into [0, dim - 1] so the indexing below can never raise.
        x_pix = min(max(x_pix, 0), img.shape[1] - 1)
        y_pix = min(max(y_pix, 0), img.shape[0] - 1)
        print(y_pix, x_pix)

        depth = img[round(y_pix), round(x_pix)]
        return depth

    def get_params(self):
        """Return the pinhole intrinsics (fx, fy, cx, cy) from camera_matrix."""
        fx = self.camera_matrix[0][0]
        fy = self.camera_matrix[1][1]
        cx = self.camera_matrix[0][2]
        cy = self.camera_matrix[1][2]
        return fx, fy, cx, cy

    def get_skeletons_with_depth(self):
        """Back-project every cached 2D skeleton to 3D using its depth frame.

        Returns a list (one entry per frame) of 25-element lists of
        ``[x, y, z]`` camera-frame coordinates. Also shows a per-frame
        debug plot via :meth:`show_depth_with_skeleton_2d`.
        """
        fx, fy, cx, cy = self.get_params()

        accumulated_skeletons = []
        for idx, sklt in enumerate(self.pose2d):
            skeleton = []
            for point in sklt:
                depth = self.get_depth(point[0], point[1], idx)
                x, y, z = self.convert_from_uvd(cx, cy, fx, fy, point[0], point[1], depth)
                skeleton.append([x, y, z])
            accumulated_skeletons.append(skeleton)
            self.show_depth_with_skeleton_2d(idx, skeleton)
            print(idx)

        return accumulated_skeletons

    def homographyFromTransform(self, T):
        """Build the 3x3 plane-induced homography from a 4x4 transform.

        Keeps columns 0, 1 and 3 (rotation x/y axes plus translation) of the
        top three rows, i.e. the mapping for points on the z = 0 plane.

        Fix: ``np.float`` was removed in NumPy 1.24 — use plain ``float``.
        """
        # Fancy indexing copies, so H is independent of T.
        H = np.array(T, dtype=float)[0:3][:, [0, 1, 3]]
        return H

    def convert_from_uvd(self, cx, cy, fx, fy, xpix, ypix, d):
        """Back-project pixel (xpix, ypix) with depth d to camera coordinates.

        Standard pinhole model, as in Open3D's
        create_point_cloud_from_depth_image:
        ``x = (u - cx) / fx * z``, ``y = (v - cy) / fy * z``, ``z = d``.
        Returns (x, y, z) in the same units as ``d``.
        """
        print('Pixel coordinates: ' + str(xpix) + ' ' + str(ypix))
        x_over_z = (xpix - cx) / fx
        y_over_z = (ypix - cy) / fy
        z = d
        x = x_over_z * z
        y = y_over_z * z
        print("X: " + str(x) + " Y: " + str(y) + " Z: " + str(z))
        return x, y, z

    def show_depth_with_skeleton_2d(self, idx, skeleton3d):
        """Debug plot for one frame: depth image with the 2D skeleton, the
        RGB image, and the back-projected 3D skeleton, side by side.

        Zero-valued coordinates mean 'keypoint not detected' and are skipped.
        Blocks until the figure window is closed (``plt.show``).
        """
        path_depth = "/home/daniela/catkin_ws/src/hpe/images/with_depth/" + self.topic + '_depth/' + str(idx) + '.png'
        path_rgb = "/home/daniela/catkin_ws/src/hpe/images/cache/with_depth/" + self.topic + '/' + str(idx) + '.png'

        fig = plt.figure()
        ax1 = fig.add_subplot(1, 3, 1)

        img_depth = np.asarray(Image.open(path_depth))
        ax1.imshow(img_depth)
        pose_2d = self.pose2d[idx]
        plt.figtext(0.02, 0.02, str(idx), fontsize=14)

        # 2D limbs over the depth image; skip limbs with undetected endpoints.
        for a, b in self.CONNECTIONS:
            if (pose_2d[a][0] != 0 and pose_2d[b][0] != 0
                    and pose_2d[a][1] != 0 and pose_2d[b][1] != 0):
                ax1.plot([pose_2d[a][0], pose_2d[b][0]],
                         [pose_2d[a][1], pose_2d[b][1]], c='red')
        for i, point in enumerate(pose_2d):
            if i not in self.SKIP_JOINTS and point[0] != 0 and point[1] != 0:
                ax1.scatter(point[0], point[1], c='b')
                ax1.text(point[0], point[1], str(i))

        ax2 = fig.add_subplot(1, 3, 2)
        ax2.imshow(np.asarray(Image.open(path_rgb)))

        ax3 = fig.add_subplot(1, 3, 3, projection='3d')
        p3ds = skeleton3d
        for a, b in self.CONNECTIONS:
            # Draw the 3D limb only if both endpoints are fully non-zero.
            if all(p3ds[j][k] != 0 for j in (a, b) for k in range(3)):
                ax3.plot(xs=[p3ds[a][0], p3ds[b][0]], ys=[p3ds[a][1], p3ds[b][1]],
                         zs=[p3ds[a][2], p3ds[b][2]], c='red')
        for i, point in enumerate(p3ds):
            if point[0] != 0 and point[1] != 0 and point[2] != 0:
                if i not in self.SKIP_JOINTS:
                    ax3.scatter(point[0], point[1], point[2], c='b')
                    ax3.text(point[0], point[1], point[2], str(i))
        plt.show()
        plt.pause(0.01)

    def show_skeletons(self, accumulated_skeletons):
        """For every frame, show the 3D skeleton (top) and 2D pose (bottom).

        Blocks on each frame until its figure window is closed.
        """
        from mpl_toolkits.mplot3d import Axes3D  # registers the 3d projection
        for idx, p3ds in enumerate(accumulated_skeletons):
            fig = plt.figure()
            ax = fig.add_subplot(2, 1, 1, projection='3d')
            ax.view_init(90, 100)
            ax.set_title('Triangulation')

            for a, b in self.CONNECTIONS:
                ax.plot(xs=[p3ds[a][0], p3ds[b][0]], ys=[p3ds[a][1], p3ds[b][1]],
                        zs=[p3ds[a][2], p3ds[b][2]], c='red')
            for i, point in enumerate(p3ds):
                if i not in self.SKIP_JOINTS:
                    ax.scatter(point[0], point[1], point[2], c='b')
                    ax.text(point[0], point[1], point[2], str(i))

            ax2 = fig.add_subplot(2, 1, 2)
            ax2.title.set_text(self.topic)
            ax2.set_xlim(0, 1200)
            # Inverted y-limits so image pixel coordinates plot upright.
            ax2.set_ylim(800, 0)
            pose_2d = self.pose2d[idx]
            for a, b in self.CONNECTIONS:
                ax2.plot([pose_2d[a][0], pose_2d[b][0]],
                         [pose_2d[a][1], pose_2d[b][1]], c='red')
            for i, point in enumerate(pose_2d):
                if i not in self.SKIP_JOINTS:
                    ax2.scatter(point[0], point[1], c='b')
                    ax2.text(point[0], point[1], str(i))

            plt.show()
            plt.pause(0.01)
        return

    def convert_pose_to_world(self):
        """Transform ``self.pose3d`` from the camera optical frame to 'world'.

        Rebases the height so the lowest point sits at z = 0.

        NOTE(review): self.pose3d is never assigned by the visible code
        (only initialised to None) — callers must set it before calling this.
        """
        T = getTransform('world', self.camera_link, self.transform_dict)

        R = T[0:3, 0:3]
        Trans = np.transpose(T[0:3, 3])

        pose3d_world = camera.camera2world(pose=self.pose3d, R=R, T=Trans)
        pose3d_world[:, :, 2] -= np.min(pose3d_world[:, :, 2])  # rebase the height
        return pose3d_world

    def get_transform_tree_dict(self, file):
        """Parse a URDF/xacro file into a transform-tree dictionary.

        Each joint becomes one entry keyed by ``generateKey(parent, child)``
        holding 'child', 'parent', 'trans' (xyz) and 'quat' (sxyz-Euler
        quaternion) — the layout getTransform expects.
        """
        xml_robot = URDF.from_xml_file(file)
        transforms = {}  # renamed from 'dict' to stop shadowing the builtin

        for joint in xml_robot.joints:
            rpy = joint.origin.rpy
            key = generateKey(joint.parent, joint.child)
            transforms[key] = {
                'child': joint.child,
                'parent': joint.parent,
                'trans': joint.origin.xyz,
                'quat': list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz')),
            }

        return transforms


def main():
    """Run the depth-based 3D pose pipeline for a single camera."""
    # TODO: extend to the multi-camera world-frame path
    # (one HPE_3D per camera -> convert_pose_to_world -> joint visualization).
    camera_node = HPE_3D('camera_1')
    skeletons = camera_node.get_skeletons_with_depth()
    camera_node.show_skeletons(skeletons)


# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
