#!/usr/bin/env python3
import sys

from openpose import pyopenpose as op

from pose_estimator_2d import openpose_estimator
from pose_estimator_2d.estimator_2d import Estimator2D
import cv2
import numpy as np
import os
from pathlib import Path
from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton
import matplotlib.pyplot as plt
import functools
import operator
from PIL import Image

# Relative path to the image directory.
# NOTE(review): not referenced by any code visible in this file (HPE builds
# absolute paths itself) — confirm whether anything else still imports it.
folder: str = "../images/"


def filter_missing_value(keypoints_list, method='ignore'):
    """Remove frames without a pose from a per-frame keypoint list.

    Args:
        keypoints_list: Per-frame results of a 2D estimator; a frame with no
            detection is represented by ``None``.
        method: How to treat missing frames. Only ``'ignore'`` (drop them) is
            supported. TODO: implement an ``'interpolate'`` strategy.

    Returns:
        A new list with every non-``None`` entry, in the original order.

    Raises:
        ValueError: If ``method`` is not a supported strategy.
    """
    if method != 'ignore':
        raise ValueError(f'{method} is not a valid method.')

    # Identity test (not truthiness) so falsy-but-present poses are kept.
    return list(filter(lambda pose: pose is not None, keypoints_list))


class HPE:
    """Run OpenPose (BODY_25) 2D pose estimation over one camera's frames.

    One instance handles one camera ``topic``. It reads every file in
    ``INPUT_ROOT + topic``, writes one skeleton overlay per frame and a
    ``2d_pose.npy`` file into ``CACHE_ROOT + topic``.
    """

    # Filesystem locations. These are class attributes so they can be
    # overridden per subclass or per instance without editing the methods.
    MODEL_FOLDER = '/home/daniela/openpose/models/'
    INPUT_ROOT = '/home/daniela/catkin_ws/src/hpe/images/with_depth/'
    CACHE_ROOT = '/home/daniela/catkin_ws/src/hpe/images/cache/with_depth/'

    def __init__(self, topic):
        """Configure and start an OpenPose wrapper for camera ``topic``.

        Args:
            topic: Camera name. It is concatenated into the input and cache
                paths, so it must be a ``str``.
        """
        self.image_raw = None
        self.image_skeleton = None
        self.topic = topic

        # render_pose=0 presumably disables OpenPose's own rendering (see the
        # OpenPose flags docs); this class draws its overlay in _draw_skeleton.
        params = {'model_folder': self.MODEL_FOLDER, 'render_pose': 0, 'model_pose': 'BODY_25',
                  'net_resolution': '-1x320'}
        self.opWrapper = op.WrapperPython()
        self.opWrapper.configure(params)
        self.opWrapper.start()
        print('Starting keypoint retreival for ' + str(self.topic))

    def scan_all_images(self):
        """Estimate 2D keypoints for every image of this topic and cache them.

        Files are processed in ``sorted(os.listdir(...))`` order. A file's
        position in that listing is the frame index ``i`` used for its
        ``<i>.png`` overlay. For each frame with a detection, the x/y columns
        of entry ``[0]`` of the OpenPose result are kept. Unreadable files and
        frames without a detection are dropped. The result is saved to
        ``CACHE_ROOT + topic + '/2d_pose.npy'``.
        """
        folder_dir = self.INPUT_ROOT + self.topic
        cache_dir = self.CACHE_ROOT + str(self.topic)
        # np.save raises, and cv2.imwrite fails, when the target folder is missing.
        os.makedirs(cache_dir, exist_ok=True)

        keypoint_list = []
        # NOTE(review): sorted() is lexicographic ('10.png' < '2.png'); confirm
        # the file names are zero-padded if frame order matters downstream.
        for i, file_name in enumerate(sorted(os.listdir(folder_dir))):
            image = cv2.imread(os.path.join(folder_dir, file_name))
            if image is None:
                # Not a readable image: skip it rather than re-appending the
                # previous frame's keypoints (or reading an unbound name).
                continue
            keypoints = self.hpe(image, i)
            # Keep entry [0] (presumably the first detected person) and drop
            # the score column, leaving (x, y) per keypoint.
            keypoint_list.append(None if keypoints is None else keypoints[0, :, :2])

        keypoint_list = filter_missing_value(
            keypoints_list=keypoint_list,
            method='ignore'  # interpolation method will be implemented later
        )
        save_file_folder = os.path.join(cache_dir, '2d_pose.npy')
        np.save(save_file_folder, keypoint_list, allow_pickle=True)
        print('Saved file with keypoint values for ' + str(self.topic) + ' with success!')

    def hpe(self, img, i):
        """Run OpenPose on one frame and write its skeleton overlay.

        Args:
            img: Image as returned by ``cv2.imread``.
            i: Frame index; the overlay is written to ``_output_path(i)``.

        Returns:
            ``datum.poseKeypoints``, or ``None`` when there is no detection.
            Callers in this class index it as ``[entry, keypoint, (x, y, score)]``.
        """
        datum = op.Datum()
        datum.cvInputData = img
        self.opWrapper.emplaceAndPop(op.VectorDatum([datum]))

        keypoints = datum.poseKeypoints
        # NOTE(review): depending on the OpenPose build, "no detection" may be
        # returned as None or as an empty array. Normalise both to None so that
        # keypoints[0] below cannot raise IndexError.
        if keypoints is not None and np.size(keypoints) == 0:
            keypoints = None

        self.write_imgs_no_bk(
            keypoints=None if keypoints is None else keypoints[0],
            img=img, skeleton=openpose_skeleton.OpenPoseSkeleton(), kp_thresh=0.4, i=i)
        return keypoints

    def vis_2d_keypoints(self, keypoints, img, skeleton, kp_thresh, alpha=0.7, show_name=False, i=0):
        """Blend a skeleton overlay (2 px bones) onto ``img``; write it as frame ``i``.

        Args:
            keypoints: Rows ``(x, y, score)`` indexed by ``skeleton.keypoint2index``,
                or ``None`` to write ``img`` unchanged.
            img: Frame to draw on; it is not modified.
            skeleton: Object exposing ``root``, ``children``, ``keypoint2index``
                and ``keypoint_num`` (e.g. ``OpenPoseSkeleton``).
            kp_thresh: A joint is drawn only if its score is strictly greater.
                ``None`` draws every joint, and the score column is not read.
            alpha: Weight of the drawn layer in the final blend.
            show_name: Also write each joint's name next to it.
            i: Frame index used in the output file name.
        """
        if keypoints is None:
            vis_result = img
        else:
            vis_result = self._draw_skeleton(keypoints, img, skeleton, kp_thresh, alpha, show_name, thickness=2)
        cv2.imwrite(self._output_path(i), vis_result)

    def write_imgs_no_bk(self, keypoints, img, skeleton, kp_thresh, alpha=0.7, show_name=False, i=0):
        """Same as ``vis_2d_keypoints`` but draws bones 3 px thick.

        NOTE(review): despite the name ("no background"), the skeleton is
        blended onto the original frame, not onto a blank canvas. Confirm
        whether a blank background was intended.
        """
        if keypoints is None:
            vis_result = img
        else:
            vis_result = self._draw_skeleton(keypoints, img, skeleton, kp_thresh, alpha, show_name, thickness=3)
        cv2.imwrite(self._output_path(i), vis_result)

    def _draw_skeleton(self, keypoints, img, skeleton, kp_thresh, alpha, show_name, thickness):
        """Return ``img`` blended with the joints and bones of ``keypoints``.

        The walk starts at ``skeleton.root`` and follows ``skeleton.children``
        using a stack (depth-first). A joint is drawn when it passes
        ``kp_thresh``; a bone is drawn when both of its ends pass.
        """
        # matplotlib returns RGBA floats in [0, 1]; OpenCV expects BGR in [0, 255].
        cmap = plt.get_cmap('rainbow')
        colors = [cmap(t) for t in np.linspace(0, 1, skeleton.keypoint_num)]
        colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

        mask = img.copy()
        stack = [skeleton.root]
        while stack:
            parent = stack.pop()
            p_idx = skeleton.keypoint2index[parent]
            p_pos = int(keypoints[p_idx, 0]), int(keypoints[p_idx, 1])
            p_visible = self._keypoint_visible(keypoints, p_idx, kp_thresh)

            if p_visible:
                cv2.circle(
                    mask, p_pos, radius=3,
                    color=colors[p_idx], thickness=-1, lineType=cv2.LINE_AA)
                if show_name:
                    cv2.putText(mask, parent, p_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

            for child in skeleton.children[parent]:
                # Skip children that this keypoint layout does not provide.
                if child not in skeleton.keypoint2index or skeleton.keypoint2index[child] < 0:
                    continue
                stack.append(child)
                c_idx = skeleton.keypoint2index[child]
                c_pos = int(keypoints[c_idx, 0]), int(keypoints[c_idx, 1])
                if p_visible and self._keypoint_visible(keypoints, c_idx, kp_thresh):
                    cv2.line(
                        mask, p_pos, c_pos,
                        color=colors[c_idx], thickness=thickness, lineType=cv2.LINE_AA)

        return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)

    @staticmethod
    def _keypoint_visible(keypoints, idx, kp_thresh):
        """Return whether joint ``idx`` passes the score threshold.

        ``kp_thresh=None`` disables filtering, and the score column is then
        not read. The check uses ``is None`` rather than truthiness, so a
        threshold of 0 still filters on score instead of comparing with None.
        """
        return kp_thresh is None or keypoints[idx, 2] > kp_thresh

    def _output_path(self, i):
        """Return the overlay path for frame ``i``: ``CACHE_ROOT + topic + '/<i>.png'``."""
        return self.CACHE_ROOT + self.topic + '/' + str(i) + '.png'


def main(cameras=('camera_2', 'camera_3', 'camera_4')):
    """Extract and cache 2D keypoints for each camera topic in ``cameras``.

    Args:
        cameras: Iterable of camera topic names (sub-folders of the image
            root). Defaults to the three cameras this script processes; pass
            e.g. ``['camera_2_sample']`` to run on a different set without
            editing the code.
    """
    for camera in cameras:
        # A fresh HPE (and therefore a fresh OpenPose wrapper) per camera.
        cam = HPE(camera)
        cam.scan_all_images()


if __name__ == "__main__":
    # Run the batch extraction only when executed as a script, not on import.
    main()
