from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import cv2
import numpy as np
import os
from pathlib import Path
import importlib
import yaml
from yaml.loader import SafeLoader
import numpy as np
from scipy.spatial import Delaunay
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
from scipy import linalg
from IPython.display import HTML
from PIL import Image

from utils import camera, vis
from utils.transforms import getTransform, getChain, getAggregateTransform
from urdf_parser_py.urdf import URDF
from atom_core.naming import generateKey
from tf.transformations import quaternion_from_matrix, quaternion_from_euler


class Triangulate():
    def __init__(self, topics):
        self.topics = topics
        file = "xacros/well_optimized.urdf.xacro"
        self.transform_dict = self.get_transform_tree_dict(file)

        self.topics_dict = {}
        self.topic_0 = 'world'
        n = 0
        for topic in self.topics:
            self.topics_dict[topic] = {}
            self.topics_dict[topic]["folder"] = "/home/daniela/catkin_ws/src/hpe/images/real/cache/" + topic + '/2d_pose.npy'
            self.topics_dict[topic]["camera_link"] = str(topic) + '_rgb_optical_frame'
            self.topics_dict[topic]["pose_2d"] = np.load(self.topics_dict[topic]["folder"], allow_pickle=True)
            yaml_folder = "/home/daniela/catkin_ws/src/hpe/images/real/" + topic + '.yaml'
            with open(yaml_folder) as f:
                data = yaml.load(f, Loader=SafeLoader)
            self.topics_dict[topic]["camera_matrix"] = np.reshape(data['camera_matrix']['data'], (3, 3))
            # print(self.topics_dict[topic]["camera_matrix"])
            # print(self.topics_dict[topic]["pose_2d"][0,:])

            if n == 0:
                self.topics_dict[topic]['R'] = np.eye(3)
                self.topics_dict[topic]['T'] = np.zeros(3)
                self.topic_0 = topic

            else:
                self.topics_dict[topic]['R'], self.topics_dict[topic]['T'] = self.get_rotation_and_translation(
                    self.topics_dict[topic]["camera_link"], self.topics_dict[self.topic_0]["camera_link"])
            # self.topics_dict[topic]['R'], self.topics_dict[topic]['T'] = self.get_rotation_and_translation(
            #     self.topics_dict[topic]["camera_link"])
            # print(self.topics_dict[topic]['R'], self.topics_dict[topic]['T'])
            self.topics_dict[topic]['P'] = self.get_params(self.topics_dict[topic]['R'], self.topics_dict[topic]['T'],
                                                           self.topics_dict[topic]["camera_matrix"])
            n = 1

    def triangulate(self, topic1, topic2):
        accumulated_skeletons = []
        for i in range(min(self.topics_dict[topic2]["pose_2d"].shape[0], self.topics_dict[topic1]["pose_2d"].shape[0])):
            skeleton = []
            for n in range(25):
                point1 = (self.topics_dict[topic1]["pose_2d"][i, n])
                point2 = (self.topics_dict[topic2]["pose_2d"][i, n])
                if (point1[0] == 0 and point1[1] == 0) or (point2[0] == 0 and point2[0] == 0):
                    point_3d = np.zeros([1, 3])[0]
                else:
                    point_3d = self.tringulate_between_pairs(self.topics_dict[topic1]['P'],
                                                             self.topics_dict[topic2]['P'],
                                                             point1, point2)
                # print(np.zeros([1,3])[0])
                skeleton.append(point_3d)
            accumulated_skeletons.append(skeleton)
        return accumulated_skeletons

    def get_rotation_and_translation(self, from_camera_link, to_camera_link):
        T = getTransform(from_camera_link, to_camera_link, self.transform_dict)
        print(T)

        R = T[0:3, 0:3]
        Trans = np.transpose(T[0:3, 3])
        print(R, Trans)

        return (R, Trans)

    def convert_pose_to_world(self, camera_link):
        T = getTransform('world', camera_link, self.transform_dict)

        R = T[0:3, 0:3]
        Trans = np.transpose(T[0:3, 3])
        print(R, Trans)

        pose3d_world = camera.camera2world(pose=self.pose3d, R=R, T=Trans)
        pose3d_world[:, :, 2] -= np.min(pose3d_world[:, :, 2])  # rebase the height
        return pose3d_world

    def get_transform_tree_dict(self, file):
        xml_robot = URDF.from_xml_file(file)
        dict = {}

        for joint in xml_robot.joints:
            child = joint.child
            parent = joint.parent
            xyz = joint.origin.xyz
            rpy = joint.origin.rpy
            key = generateKey(parent, child)

            dict[key] = {}
            dict[key]['child'] = child
            dict[key]['parent'] = parent
            dict[key]['trans'] = xyz
            dict[key]['quat'] = list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz'))

        return dict

    def get_params(self, R, T, mtx):
        # RT matrix for C1 is identity.
        # print(R.shape)
        # print(np.reshape(T, (3,1)))
        # print(T[0])
        RT = np.concatenate([R, np.reshape(T, (3, 1))], axis=-1)
        P = mtx @ RT  # projection matrix for C1
        print('projection')
        print(P)
        return P

    def tringulate_between_pairs(self, P1, P2, point1, point2):  # P - projection matrix
        A = [point1[1] * P1[2, :] - P1[1, :],
             P1[0, :] - point1[0] * P1[2, :],
             point2[1] * P2[2, :] - P2[1, :],
             P2[0, :] - point2[0] * P2[2, :]
             ]
        A = np.array(A).reshape((4, 4))
        # print('A: ')
        # print(A)

        B = A.transpose() @ A
        U, s, Vh = linalg.svd(B, full_matrices=False)

        # print('Triangulated point: ')
        # print(Vh[3, 0:3] / Vh[3, 3])
        return Vh[3, 0:3] / Vh[3, 3]
        # return

    def show_skeletons(self, accumulated_skeletons):
        from mpl_toolkits.mplot3d import Axes3D
        idx = 0
        for p3ds in accumulated_skeletons:
            # print(len(p3ds))
            fig = plt.figure()
            ax = fig.add_subplot(3, 2, 1, projection='3d')
            ax.view_init(90, 100)
            ax.set_title('Triangulation')
            #
            # connections = [[0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7], [1, 8], [1, 9], [2, 8], [5, 9], [8, 9],
            #                [0, 10], [0, 11]]

            # connections = [[17, 15], [15, 0], [0, 16], [16, 18], [0, 1], [1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7],
            #                [1, 8], [8, 9], [9, 10], [10, 11], [11, 24], [11, 22], [22, 23], [8, 12], [12, 13], [13, 14],
            #                [14, 21], [14, 19], [19, 20]]
            connections = [[1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7],
                           [1, 8], [8, 9], [9, 10], [10, 11], [11, 24], [11, 22], [22, 23], [8, 12], [12, 13], [13, 14],
                           [14, 21], [14, 19], [19, 20]]
            for _c in connections:
                # print(p3ds[_c[0]][0])
                # print(p3ds[_c[1]])
                if p3ds[_c[0]][0] != 0 and p3ds[_c[1]][0] != 0 and p3ds[_c[0]][1] != 0 and p3ds[_c[1]][1] != 0 and \
                        p3ds[_c[0]][2] != 0 and p3ds[_c[1]][2] != 0:
                    ax.plot(xs=[p3ds[_c[0]][0], p3ds[_c[1]][0]], ys=[p3ds[_c[0]][1], p3ds[_c[1]][1]],
                            zs=[p3ds[_c[0]][2], p3ds[_c[1]][2]], c='red')
            i = 0
            plt.figtext(0.02, 0.02, 'frame: ' + str(idx), fontsize=14)
            for point in p3ds:
                # print(point)
                if point[0] != 0 and point[1] != 0 and point[2] != 0:
                    if (i not in [0, 15, 16, 17, 18]):
                        ax.scatter(point[0], point[1], point[2], c='b')
                        ax.text(point[0], point[1], point[2], str(i))
                i = i + 1

            n = 3
            for topic in self.topics:
                ax2 = fig.add_subplot(3, 2, n)
                # ax2.set_aspect('equal', 'box')
                ax2.title.set_text(topic)
                ax2.set_xlim(0, 1200)
                ax2.set_ylim(800, 0)
                pose_2d = self.topics_dict[topic]["pose_2d"][idx]

                for _c in connections:
                    if pose_2d[_c[0]][0] != 0 and pose_2d[_c[1]][0] != 0 and pose_2d[_c[0]][
                        1] != 0 and pose_2d[_c[1]][1] != 0:
                        ax2.plot([pose_2d[_c[0]][0], pose_2d[_c[1]][0]], [pose_2d[_c[0]][1], pose_2d[_c[1]][1]],
                                 c='red')
                i = 0
                for point in pose_2d:
                    if i not in [0, 15, 16, 17, 18] and (point[0] != 0 and point[1] != 0):
                        # ax2.axes.invert_yaxis()
                        ax2.scatter(point[0], point[1], c='b')
                        ax2.text(point[0], point[1], str(i))
                        # ax2.invert_yaxis()
                    i = i + 1

                folder = "/home/daniela/catkin_ws/src/hpe/images/real/cache/" + topic + '/' + str(idx) + '.png'
                ax3 = fig.add_subplot(3, 2, n + 1)
                img = np.asarray(Image.open(folder))
                ax3.imshow(img)

                n = n + 2
            idx = idx + 1

            # n=2
            # for topic in self.topics:
            #     ax2 = fig.add_subplot(3, 1, n)
            #     pose_2d = self.topics_dict[topic]["pose_2d"]
            #     for _c in connections:
            #         ax2.plot(xs=[pose_2d[_c[0]][0], pose_2d[_c[1]][0]], ys=[pose_2d[_c[0]][1], pose_2d[_c[1]][1]], c='red')
            #     i=0
            #     for point in pose_2d:
            #         if i not in [0, 15, 16, 17, 18]:
            #             ax2.scatter(point[0], point[1], c='b')
            #             ax2.text(point[0], point[1], str(i))
            #         i = i + 1
            #
            #     n=n+1

            plt.show()
            plt.pause(0.01)
        return

    def visualize(self, pose3d_world):
        gif_file = '/home/daniela/catkin_ws/src/hpe/images/cache/tringulated_pose.mp4'  # output format can be .gif or .mp4
        h36m_skel = h36m_skeleton.H36mSkeleton()
        ani = vis.vis_3d_keypoints_sequence(
            keypoints_sequence=pose3d_world[0:500],
            skeleton=h36m_skel,
            azimuth=0,
            fps=10,
            output_file=gif_file
        )
        HTML(ani.to_jshtml())

    def combine_skeletons(self):
        return


def main():
    """Triangulate the camera_2 / camera_3 pair and display every frame."""
    # Other camera sets used during development:
    #   ['camera_2'] and ['camera_2', 'camera_3', 'camera_4']
    first_cam, second_cam = 'camera_2', 'camera_3'
    triangulator = Triangulate([first_cam, second_cam])
    skeletons = triangulator.triangulate(first_cam, second_cam)
    triangulator.show_skeletons(skeletons)
    # triangulator.visualize(skeletons)


if __name__ == "__main__":
    main()
