#!/usr/bin/env python3
import sys

# from openpose import pyopenpose as op

from pose_estimator_2d import openpose_estimator
from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
import functools
import operator
from natsort import natsorted

from utils.links import joint_idxs_op_mpi, links_op_mpi

# Relative images folder.
# NOTE(review): not referenced anywhere else in this file; the HPE class builds
# its own absolute paths instead — confirm whether this can be removed.
folder = "../images/"


def filter_missing_value(keypoints_list, method='ignore'):
    """Filter missing values out of a list of poses.

    Args:
        keypoints_list: Estimates returned by a 2D estimator; a missing
            estimate is represented by ``None``.
        method: How to handle missing values. Only ``'ignore'`` (drop them)
            is supported.
            TODO: implement an 'interpolate' method.

    Returns:
        A new list containing the non-``None`` entries in their original order.

    Raises:
        ValueError: if ``method`` is not a supported method.
    """
    if method != 'ignore':
        raise ValueError(f'{method} is not a valid method.')
    return [pose for pose in keypoints_list if pose is not None]


class HPE():
    """Run OpenPose on one camera's image sequence and cache the 2D keypoints.

    Images are read from ``<images_root>/<dataset_name>/imageSequence/<topic>``
    in natural sort order. For each frame, the first entry of OpenPose's
    ``poseKeypoints`` (presumably the first detected person) is kept. The
    result is saved to ``<images_root>/<dataset_name>/cache/<topic>/2d_pose_op.npy``.

    NOTE(review): this class relies on the OpenPose bindings as ``op``, but
    ``from openpose import pyopenpose as op`` is commented out at the top of
    the file. Instantiation raises NameError until that import is restored.
    """

    def __init__(self, topic, debug=False, save_images=True,
                 images_root='/home/daniela/catkin_ws/src/hpe/images/',
                 model_folder='/home/daniela/openpose/models/'):
        """Configure and start the OpenPose wrapper for one camera.

        Args:
            topic: camera folder name (e.g. 'camera_0'); used in the input and
                cache paths.
            debug: stored on the instance; not read by the code in this class.
            save_images: when True, every processed frame is drawn and shown
                with ``vis_2d_keypoints`` (blocking until a key is pressed).
            images_root: root folder that contains ``dataset_name``.
            model_folder: OpenPose ``model_folder`` parameter.
        """
        self.debug = debug
        self.image_raw = None
        self.image_skeleton = None
        self.topic = topic
        self.save_images = save_images
        self.images_root = images_root
        self.dataset_name = 'mpi-inf-3dhp/S1/Seq1/'

        params = {'model_folder': model_folder, 'render_pose': 0, 'model_pose': 'MPI',
                  'net_resolution': '320x176'}
        self.opWrapper = op.WrapperPython()
        self.opWrapper.configure(params)
        self.opWrapper.start()
        print('Starting keypoint retreival for ' + str(self.topic))

        # Placeholder used when OpenPose detects nobody: a single "person"
        # whose joints are all (x=0, y=0, confidence=0).
        # NOTE(review): 25 joints matches BODY_25, but model_pose is 'MPI'
        # (15 keypoints in OpenPose). Mixing both shapes makes the list saved
        # in scan_all_images ragged. Confirm the intended joint count.
        self.empty_skeleton = [[[0, 0, 0] for _ in range(25)]]

        # Skeleton description handed to vis_2d_keypoints; built once, not per frame.
        self.op_skel = openpose_skeleton.OpenPoseSkeleton()

    def scan_all_images(self):
        """Estimate keypoints for every readable image of this camera and save them.

        Files that ``cv2.imread`` cannot decode are skipped. For every other
        frame, the first 3 columns of the first pose entry are collected.
        The list is written with ``np.save`` to ``cache/<topic>/2d_pose_op.npy``,
        and the cache folder is created if it does not exist.
        """
        folder_dir = os.path.join(self.images_root, self.dataset_name, 'imageSequence', str(self.topic))
        keypoint_list = []

        for image_name in natsorted(os.listdir(folder_dir)):
            image = cv2.imread(os.path.join(folder_dir, image_name))
            if image is None:
                # Not a decodable image. Skip it, so the previous frame's
                # keypoints are not appended a second time.
                continue
            keypoints = self.hpe(image, image_name)
            keypoint_list.append(keypoints[0, :, :3])

        keypoint_list = filter_missing_value(
            keypoints_list=keypoint_list,
            method='ignore'
        )

        save_dir = os.path.join(self.images_root, self.dataset_name, 'cache', str(self.topic))
        os.makedirs(save_dir, exist_ok=True)
        np.save(os.path.join(save_dir, '2d_pose_op.npy'), keypoint_list, allow_pickle=True)
        print('Saved file with keypoint values for ' + str(self.topic) + ' with success!')

    def hpe(self, img, image_name):
        """Run OpenPose on one image and return its pose keypoints.

        Args:
            img: image as returned by ``cv2.imread``.
            image_name: file name, forwarded to ``vis_2d_keypoints``.

        Returns:
            ``datum.poseKeypoints``, or ``np.array(self.empty_skeleton)`` when
            OpenPose returns no usable detection.
        """
        datum = op.Datum()
        datum.cvInputData = img
        self.opWrapper.emplaceAndPop(op.VectorDatum([datum]))

        # Other datum attributes include 'poseScores', 'poseIds', 'poseHeatMaps',
        # 'handKeypoints', 'faceKeypoints', 'cvOutputData', 'netInputSizes', ...

        keypoints = datum.poseKeypoints
        # Defensively treat None / empty / unexpectedly-shaped output as "nobody
        # detected", so callers can always index keypoints[0, :, :3].
        if keypoints is None or np.ndim(keypoints) != 3 or keypoints.shape[0] == 0:
            keypoints = np.array(self.empty_skeleton)

        if self.save_images:
            self.vis_2d_keypoints(keypoints=keypoints[0], img=img, image_name=image_name,
                                  skeleton=self.op_skel, kp_thresh=0.4)

        return keypoints

    def vis_2d_keypoints(self, keypoints, img, image_name, skeleton, kp_thresh, alpha=0.7, show_name=False):
        """Draw one person's joints and links on ``img`` and display it.

        Args:
            keypoints: rows of (x, y, confidence), indexed by the ``'idx'``
                entries of ``joint_idxs_op_mpi``.
            img: image to draw on; drawn on in place, then downscaled by 2 for display.
            image_name: currently unused.
            skeleton: currently unused.
            kp_thresh: joints with confidence below this are not drawn, and
                links touching such joints are skipped.
            alpha: currently unused.
            show_name: currently unused.

        Blocks on ``cv2.waitKey(0)`` until a key is pressed.
        """
        if keypoints is None:
            return

        cmap = plt.get_cmap('rainbow')
        # matplotlib returns RGBA floats in [0, 1]; OpenCV expects 0-255 BGR.
        colors = [(c[2] * 255, c[1] * 255, c[0] * 255)
                  for c in (cmap(v) for v in np.linspace(0, 1, len(keypoints)))]

        def pixel(idx):
            # Integer pixel coordinates of joint idx.
            point = keypoints[idx]
            return int(point[0]), int(point[1])

        def visible(idx):
            # A joint at pixel (0, 0) is treated as missing (empty_skeleton is all zeros).
            return pixel(idx) != (0, 0) and keypoints[idx][2] >= kp_thresh

        for joint in joint_idxs_op_mpi.values():
            idx = joint['idx']
            if visible(idx):
                img = cv2.circle(img, pixel(idx), 0, color=colors[idx], thickness=10)

        for link in links_op_mpi.values():
            idx_0 = joint_idxs_op_mpi[link['parent']]['idx']
            idx_1 = joint_idxs_op_mpi[link['child']]['idx']
            if visible(idx_0) and visible(idx_1):
                cv2.line(img, pixel(idx_0), pixel(idx_1), (128, 128, 128), 3)

        img = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))
        cv2.imshow("Window", img)
        cv2.waitKey(0)

        # TODO: honour self.debug / persist the result, e.g. cv2.imwrite the
        # annotated frame to <images_root>/<dataset_name>/cache/<topic>/<image_name>
        # instead of (or in addition to) the blocking imshow above.
        return


def main():
    """Extract and cache OpenPose 2D keypoints for each selected camera, in order."""
    # Camera selections used at other times during development:
    #   ['camera_1', 'camera_2', 'camera_4', 'camera_5', 'camera_6', 'camera_7', 'camera_8']
    #   ['camera_1', 'camera_3'] / ['camera_2_sample'] / ['camera_5'] / ['tests']
    #   ['54138969', '55011271', '58860488', '60457274']
    selected_cameras = (
        'camera_0', 'camera_1', 'camera_2', 'camera_4',
        'camera_5', 'camera_6', 'camera_7', 'camera_8',
    )

    for camera_name in selected_cameras:
        estimator = HPE(camera_name)
        estimator.scan_all_images()


# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
