#!/usr/bin/env python3

import os
import cv2
import numpy as np
from natsort import natsorted

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2


def draw_landmarks_on_image(rgb_image, detection_result):
    """Draw every detected pose skeleton on a copy of ``rgb_image``.

    Args:
        rgb_image: image array (e.g. from ``mp.Image.numpy_view()``). It is
            copied with ``np.copy``; the input is never modified.
        detection_result: a Pose Landmarker result whose ``pose_landmarks``
            attribute is a sequence of per-pose landmark sequences.

    Returns:
        ``(annotated_image, landmarks)``. ``landmarks`` holds the normalized
        landmarks of the *last* detected pose, as a protobuf repeated field.
        It is an empty list when no pose was detected.
    """
    annotated_image = np.copy(rgb_image)
    # Stays empty when nothing was detected, so the return below is always
    # well-defined (the proto is only built inside the loop).
    last_pose_landmarks = []

    for pose_landmarks in detection_result.pose_landmarks:
        # solutions.drawing_utils.draw_landmarks takes a
        # landmark_pb2.NormalizedLandmarkList, so copy the landmarks into one.
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        pose_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(
                x=landmark.x, y=landmark.y, z=landmark.z, visibility=landmark.visibility)
            for landmark in pose_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style())
        last_pose_landmarks = pose_landmarks_proto.landmark

    return annotated_image, last_pose_landmarks


class hpe_mediapipe():
    """Run the MediaPipe Pose Landmarker over one camera of a processed sequence.

    Images are read from
    ``../images/human36m/processed/<sector>/<action>/imageSequence/<camera>/``
    (relative to the working directory). The per-frame 2D keypoints are saved
    with ``np.save`` to a ``cache/<camera>/2d_pose_mediapipe.npy`` file.
    """

    # The Pose Landmarker model outputs 33 landmarks per pose. Frames without a
    # detection are padded to this many rows so the saved array stays
    # rectangular (NumPy refuses to save ragged nested lists).
    NUM_LANDMARKS = 33

    def __init__(self, camera, sector='S1', action='Directions-1', show_images=False):
        """Load the heavy pose-landmarker model and resolve the image folder.

        Args:
            camera: camera identifier string, used as a sub-folder name.
            sector: subject folder name (e.g. ``'S1'``).
            action: action folder name (e.g. ``'Directions-1'``).
            show_images: if True, ``scan_all_images`` displays each frame with
                its keypoints and waits for a key press.
        """
        self.model_path = '../models/mediapipe/pose_landmarker_heavy.task'
        self.camera = camera
        self.show_images = show_images
        self.dataset_name = '../images/human36m/processed/' + sector + '/' + action

        BaseOptions = mp.tasks.BaseOptions
        PoseLandmarker = mp.tasks.vision.PoseLandmarker
        PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
        VisionRunningMode = mp.tasks.vision.RunningMode

        options = PoseLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=self.model_path),
            running_mode=VisionRunningMode.IMAGE)

        self.landmarker = PoseLandmarker.create_from_options(options)

        self.dataset_folder = self.dataset_name + '/imageSequence/' + self.camera + '/'

    def scan_all_images(self):
        """Detect pose keypoints in every image of the camera folder and save them.

        Images are processed in natural-sort order. Each frame contributes one
        list of ``[x_px, y_px, visibility]`` rows:

        - landmarks outside the image get NaN pixel coordinates;
        - frames with no detected pose are ``NUM_LANDMARKS`` all-NaN rows.

        The collected list is printed and then written with ``np.save``.
        """
        keypoint_list = []
        save_file_folder = '/home/daniela/catkin_ws/src/hpe/images/' + self.dataset_name + '/cache/' + str(
            self.camera) + '/2d_pose_mediapipe.npy'

        for image_name in natsorted(os.listdir(self.dataset_folder)):
            image_path = os.path.join(self.dataset_folder, image_name)
            mp_image = mp.Image.create_from_file(image_path)
            rgb_image = mp_image.numpy_view()

            pose_landmarker_result = self.landmarker.detect(mp_image)
            if pose_landmarker_result.pose_landmarks:
                _, keypoints_mp = draw_landmarks_on_image(rgb_image, pose_landmarker_result)
            else:
                keypoints_mp = []

            height, width = rgb_image.shape[:2]
            keypoints = self._to_pixel_keypoints(keypoints_mp, width, height)

            if self.show_images:
                self._show_keypoints(image_path, image_name, keypoints)

            keypoint_list.append(keypoints)

        print(keypoint_list)
        # np.save does not create missing parent directories.
        os.makedirs(os.path.dirname(save_file_folder), exist_ok=True)
        np.save(save_file_folder, keypoint_list, allow_pickle=True)

    @classmethod
    def _to_pixel_keypoints(cls, landmarks, width, height):
        """Convert normalized landmarks to ``[x_px, y_px, visibility]`` rows.

        An empty ``landmarks`` (no pose detected) yields ``cls.NUM_LANDMARKS``
        all-NaN rows. MediaPipe's converter returns None for a landmark whose
        normalized x/y lie outside [0, 1]; such a landmark gets NaN pixel
        coordinates but keeps its visibility.
        """
        if len(landmarks) == 0:
            return [[np.nan, np.nan, np.nan] for _ in range(cls.NUM_LANDMARKS)]

        keypoints = []
        for pt in landmarks:
            pixel = solutions.drawing_utils._normalized_to_pixel_coordinates(pt.x, pt.y, width,
                                                                             height)
            if pixel is None:
                keypoints.append([np.nan, np.nan, pt.visibility])
            else:
                keypoints.append([pixel[0], pixel[1], pt.visibility])
        return keypoints

    @staticmethod
    def _show_keypoints(image_path, image_name, keypoints):
        """Show the image with each keypoint drawn and labelled by its index.

        Blocks until a key is pressed, then closes the window. Keypoints with
        NaN coordinates are skipped, but every label keeps its landmark index.
        """
        frame = cv2.imread(image_path)
        for idx, (x, y, _) in enumerate(keypoints):
            if np.isnan(x) or np.isnan(y):
                continue
            cv2.circle(frame, (x, y), 0, color=(0, 0, 255), thickness=3)
            cv2.putText(frame, str(idx), (x, y), cv2.FONT_HERSHEY_PLAIN, 0.5, (0, 255, 0))

        cv2.imshow(image_name, frame)
        cv2.waitKey(0)
        # Each frame uses its own window name; close it so windows don't pile up.
        cv2.destroyWindow(image_name)


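# Legacy single-image prototype, superseded by the hpe_mediapipe class above.
# Kept for reference only and out of date: draw_landmarks_on_image now returns
# an (image, landmarks) tuple rather than just the annotated image.
# def hpe_mediapipe():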
#     model_path = '../models/mediapipe/pose_landmarker_heavy.task'
#
#     BaseOptions = mp.tasks.BaseOptions
#     PoseLandmarker = mp.tasks.vision.PoseLandmarker
#     PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
#     VisionRunningMode = mp.tasks.vision.RunningMode
#
#     options = PoseLandmarkerOptions(
#         base_options=BaseOptions(model_asset_path=model_path),
#         running_mode=VisionRunningMode.IMAGE)
#
#     with PoseLandmarker.create_from_options(options) as landmarker:
#         # print()
#
#         # Load the input image from an image file.
#         # mp_image = mp.Image.create_from_file('../images/human36m/processed/S1/Directions-1/imageSequence/54138969/img_000001.jpg')
#         mp_image = mp.Image.create_from_file(
#             '../images/human36m/processed/S1/Directions-1/imageSequence/60457274/img_000159.jpg')
#
#         pose_landmarker_result = landmarker.detect(mp_image)
#         # print(pose_landmarker_result)
#
#         annotated_image = draw_landmarks_on_image(mp_image.numpy_view(), pose_landmarker_result)
#         cv2.imshow('window', cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR))
#         cv2.waitKey(0)
#
#         exit(0)
#
#         # print(pose_landmarker_result)


def main():
    """Run pose estimation for each hard-coded S1 / Directions-1 camera.

    Display is enabled, so every processed frame waits for a key press.
    """
    camera_ids = ('54138969', '55011271', '58860488', '60457274')
    display = True

    for camera_id in camera_ids:
        estimator = hpe_mediapipe(camera_id, sector='S1', action="Directions-1", show_images=display)
        estimator.scan_all_images()


if __name__ == "__main__":
    main()
