#!/usr/bin/env python3
import sys

from openpose import pyopenpose as op

from pose_estimator_2d import openpose_estimator
from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
import functools
import operator
from natsort import natsorted

# Relative path to an images folder.
# NOTE(review): this constant is not referenced anywhere in this file; the HPE
# class builds absolute paths under '/home/daniela/catkin_ws/src/hpe/images/'
# instead. Confirm whether it is still needed before relying on it.
folder = "../images/"


def filter_missing_value(keypoints_list, method='ignore'):
    """Remove missing entries from a list of per-frame pose estimates.

    Args:
        keypoints_list: Results returned by the 2D estimator, one entry per
            frame; a frame without an estimate is represented by None.
        method: Strategy for missing frames. Only 'ignore' (drop them) is
            supported.
    Return:
        A new list containing only the non-None entries, in original order.
    Raises:
        ValueError: if ``method`` is not a supported strategy.
    """
    # TODO: implement an 'interpolate' method.
    if method != 'ignore':
        raise ValueError(f'{method} is not a valid method.')
    return [frame for frame in keypoints_list if frame is not None]


class HPE:
    """Run OpenPose on the image sequence of one camera and cache the results.

    For every image in ``<images_root><dataset_name>imageSequence/<topic>`` the
    keypoints are estimated and an annotated copy of the frame is produced. In
    normal mode the copy is written to the camera's cache folder; in debug mode
    it is shown with matplotlib instead. At the end, all per-frame keypoints
    are saved to ``<cache folder>/2d_pose.npy``.
    """

    # Root directory that contains both the dataset images and the cache folders.
    images_root = '/home/daniela/catkin_ws/src/hpe/images/'

    # Keypoint index pairs joined by a red line in the debug plot.
    debug_connections = [[1, 2], [2, 3], [3, 4], [1, 5], [5, 6], [6, 7],
                         [1, 8], [8, 9], [9, 10], [10, 11], [11, 24], [11, 22], [22, 23], [8, 12], [12, 13], [13, 14],
                         [14, 21], [14, 19], [19, 20]]

    def __init__(self, topic, debug=False):
        """Configure and start an OpenPose wrapper for one camera.

        Args:
            topic: camera identifier; used as the sub-folder name for both the
                input images and the cache output.
            debug: when True, annotated frames are displayed with matplotlib
                instead of being written to disk.
        """
        self.debug = debug
        self.image_raw = None
        self.image_skeleton = None
        self.topic = topic
        self.dataset_name = 'human36m/processed/S1/Directions-1/'

        params = {'model_folder': '/home/daniela/openpose/models/', 'render_pose': 0, 'model_pose': 'BODY_25',
                  'net_resolution': '320x176'}
        self.opWrapper = op.WrapperPython()
        self.opWrapper.configure(params)
        self.opWrapper.start()
        print('Starting keypoint retreival for ' + str(self.topic))

        # Placeholder used by hpe() when OpenPose detects nobody: one "person"
        # with 25 all-zero (x, y, score) keypoints.
        self.empty_skeleton = [[[0, 0, 0] for _ in range(25)]]

    def _cache_dir(self):
        """Return this camera's cache folder path, including a trailing slash."""
        return self.images_root + self.dataset_name + 'cache/' + str(self.topic) + '/'

    def scan_all_images(self):
        """Estimate keypoints for every image of this camera and save them.

        Images are processed in natural-sort order. Files that OpenCV cannot
        read are recorded as missing frames and then dropped by
        ``filter_missing_value``. For each remaining frame, the keypoints of
        the first detected person (``keypoints[0, :, :3]``) are kept, and the
        list is saved to ``<cache folder>/2d_pose.npy``.
        """
        folder_dir = self.images_root + self.dataset_name + 'imageSequence/' + str(self.topic)
        cache_dir = self._cache_dir()
        # Annotated frames and the .npy file are written here; neither
        # cv2.imwrite nor np.save creates missing parent folders.
        os.makedirs(cache_dir, exist_ok=True)

        keypoint_list = []
        for image_name in natsorted(os.listdir(folder_dir)):
            image = cv2.imread(os.path.join(folder_dir, image_name))
            if image is None:
                # Unreadable file: mark the frame as missing instead of reusing
                # the previous frame's keypoints (or hitting an unbound name on
                # the very first frame).
                keypoint_list.append(None)
                continue
            keypoints = self.hpe(image, image_name)
            keypoint_list.append(keypoints[0, :, :3])

        keypoint_list = filter_missing_value(
            keypoints_list=keypoint_list,
            method='ignore'  # interpolation method will be implemented later
        )

        np.save(cache_dir + '2d_pose.npy', keypoint_list, allow_pickle=True)
        print('Saved file with keypoint values for ' + str(self.topic) + ' with success!')

    def hpe(self, img, image_name):
        """Run OpenPose on one frame, then save or show its visualization.

        Args:
            img: image as returned by ``cv2.imread``.
            image_name: file name of the frame; used for the cached output.
        Returns:
            ``datum.poseKeypoints`` (indexed as ``[person, keypoint, channel]``),
            or ``np.array(self.empty_skeleton)`` when OpenPose returns None.
        """
        datum = op.Datum()
        datum.cvInputData = img
        self.opWrapper.emplaceAndPop(op.VectorDatum([datum]))

        keypoints = datum.poseKeypoints
        if keypoints is None:
            keypoints = np.array(self.empty_skeleton)

        self.vis_2d_keypoints(keypoints=keypoints[0], img=img, image_name=image_name,
                              skeleton=openpose_skeleton.OpenPoseSkeleton(), kp_thresh=0.4)
        return keypoints

    @staticmethod
    def _skeleton_primitives(keypoints, skeleton, kp_thresh):
        """Walk the skeleton tree from its root and list what should be drawn.

        Args:
            keypoints: array indexed as ``keypoints[idx, 0/1/2]`` -> x, y, score.
            skeleton: object with ``root``, ``children`` (name -> child names)
                and ``keypoint2index`` (name -> index; negative means unused).
            kp_thresh: a joint is drawn only if its score is strictly greater;
                None draws every joint.
        Returns:
            Primitives in drawing order: ``('joint', idx, (x, y), name)`` and
            ``('bone', child_idx, parent_xy, child_xy)``.
        """
        def visible(idx):
            # Compare with "is None" so a threshold of 0 is a real threshold.
            return kp_thresh is None or keypoints[idx, 2] > kp_thresh

        def position(idx):
            return int(keypoints[idx, 0]), int(keypoints[idx, 1])

        primitives = []
        stack = [skeleton.root]
        while stack:
            parent = stack.pop()
            p_idx = skeleton.keypoint2index[parent]
            p_pos = position(p_idx)
            p_visible = visible(p_idx)
            if p_visible:
                primitives.append(('joint', p_idx, p_pos, parent))

            for child in skeleton.children[parent]:
                if child not in skeleton.keypoint2index or skeleton.keypoint2index[child] < 0:
                    continue
                stack.append(child)
                c_idx = skeleton.keypoint2index[child]
                # A bone is drawn only when both of its end joints are visible.
                if p_visible and visible(c_idx):
                    primitives.append(('bone', c_idx, p_pos, position(c_idx)))
        return primitives

    def _draw_skeleton(self, img, keypoints, skeleton, kp_thresh, alpha, show_name, thickness):
        """Return ``img`` blended (weight ``alpha``) with the drawn skeleton."""
        # matplotlib colormaps give 0-1 RGBA; OpenCV drawing wants 0-255 BGR.
        cmap = plt.get_cmap('rainbow')
        colors = [(c[2] * 255, c[1] * 255, c[0] * 255)
                  for c in [cmap(v) for v in np.linspace(0, 1, skeleton.keypoint_num)]]

        mask = img.copy()
        for kind, idx, pos, extra in self._skeleton_primitives(keypoints, skeleton, kp_thresh):
            if kind == 'joint':
                cv2.circle(mask, pos, radius=3, color=colors[idx], thickness=-1, lineType=cv2.LINE_AA)
                if show_name:
                    cv2.putText(mask, extra, pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
            else:
                cv2.line(mask, pos, extra, color=colors[idx], thickness=thickness, lineType=cv2.LINE_AA)
        return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)

    def _show_debug_figure(self, keypoints, vis_result, image_name):
        """Display the annotated frame next to a labelled scatter plot of keypoints."""
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 2, 1)
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.title.set_text('skeleton')
        ax2.set_xlim(0, 1200)
        ax2.set_ylim(800, 0)  # inverted so y grows downwards, like image rows

        if keypoints is not None:
            for a, b in self.debug_connections:
                ax2.plot([keypoints[a][0], keypoints[b][0]], [keypoints[a][1], keypoints[b][1]], c='red')
            for idx, point in enumerate(keypoints):
                # Indices 0 and 15-18 are deliberately not plotted; zero
                # coordinates are treated as undetected.
                if idx not in (0, 15, 16, 17, 18) and point[0] != 0 and point[1] != 0:
                    ax2.scatter(point[0], point[1], c='b')
                    ax2.text(point[0], point[1], str(idx))
                    ax2.text(point[0] + 10, point[1] + 5, str(round(point[2], 2)), c='r')

        plt.figtext(0.02, 0.02, str(image_name), fontsize=14)
        # OpenCV images are BGR while matplotlib expects RGB.
        ax1.imshow(cv2.cvtColor(np.asarray(vis_result), cv2.COLOR_BGR2RGB))
        plt.show()
        plt.pause(0.01)
        # A fresh figure is created per frame; release it so they don't pile up.
        plt.close(fig)

    def vis_2d_keypoints(self, keypoints, img, image_name, skeleton, kp_thresh, alpha=0.7, show_name=False):
        """Overlay one person's keypoints on ``img``; write it to the cache or show it.

        Args:
            keypoints: rows of (x, y, score) for one person, or None.
            img: image the keypoints were estimated on (OpenCV BGR layout).
            image_name: file name for the cached image / caption of the debug plot.
            skeleton: skeleton description consumed by ``_skeleton_primitives``
                (also needs ``keypoint_num`` for the color map).
            kp_thresh: minimum score (exclusive) to draw a joint; None draws all.
            alpha: blend weight of the annotated layer.
            show_name: also write each joint's name next to it.
        """
        if keypoints is None:
            # Nothing to overlay: keep the plain frame rather than failing on
            # an undefined result.
            vis_result = img.copy()
        else:
            vis_result = self._draw_skeleton(img, keypoints, skeleton, kp_thresh, alpha, show_name, thickness=2)

        if self.debug:
            self._show_debug_figure(keypoints, vis_result, image_name)
        else:
            cv2.imwrite(self._cache_dir() + image_name, vis_result)

    def write_imgs_no_bk(self, keypoints, img, skeleton, kp_thresh, alpha=0.7, show_name=False, i=0):
        """Draw ``keypoints`` on ``img`` (thicker bones) and write ``<cache folder>/<i>.png``.

        When ``keypoints`` is None the unmodified frame is written.
        """
        if keypoints is None:
            vis_result = img.copy()
        else:
            vis_result = self._draw_skeleton(img, keypoints, skeleton, kp_thresh, alpha, show_name, thickness=3)
        cv2.imwrite(self._cache_dir() + str(i) + '.png', vis_result)


# return vis_result


def main():
    """Extract and cache 2D keypoints for each camera folder of the sequence.

    One HPE instance (and so one OpenPose wrapper) is created per camera, and
    that camera's images are fully processed before moving to the next one.
    """
    camera_ids = ['54138969', '55011271', '58860488', '60457274']
    for camera_id in camera_ids:
        estimator = HPE(camera_id)
        estimator.scan_all_images()


if __name__ == "__main__":
    main()
