#!/usr/bin/env python3

import os
import cv2
import numpy as np
from natsort import natsorted

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2

# import _ "image/jpeg"




def draw_landmarks_on_image(rgb_image, detection_result):
    """Draw every detected pose onto a copy of *rgb_image*.

    Args:
        rgb_image: Image array to annotate. It is not modified; drawing happens
            on an ``np.copy`` of it. Presumably RGB, as produced by
            ``mp.Image.numpy_view()`` in the caller — confirm for other callers.
        detection_result: PoseLandmarker result whose ``pose_landmarks`` holds
            one sequence of normalized landmarks per detected pose.

    Returns:
        Tuple ``(annotated_image, landmarks)``. ``landmarks`` is the
        ``NormalizedLandmarkList.landmark`` field of the LAST pose that was
        drawn, or an empty field when no pose was detected.
    """
    annotated_image = np.copy(rgb_image)
    # Start with an empty list so callers get no landmarks when nothing was detected.
    pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()

    for pose_landmarks in detection_result.pose_landmarks:
        # drawing_utils works on the protobuf landmark list, so convert the
        # tasks-API landmarks (keeping visibility) before drawing.
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        pose_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(
                x=landmark.x, y=landmark.y, z=landmark.z, visibility=landmark.visibility)
            for landmark in pose_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style())

    return annotated_image, pose_landmarks_proto.landmark


class hpe_mediapipe():
    """Extract 2D MediaPipe pose keypoints for one camera of an MPI-INF-3DHP sequence.

    The constructor builds a MediaPipe ``PoseLandmarker`` in IMAGE running mode.
    :meth:`scan_all_images` runs it over every file in the camera's
    ``imageSequence`` folder and saves the per-frame keypoints with ``np.save``.
    Call :meth:`close` when done, or use the instance as a context manager, so
    the landmarker's resources are released.
    """

    # Number of landmarks per pose emitted by the MediaPipe pose model
    # (``len(solutions.pose.PoseLandmark)``). Frames without a detection are
    # padded to this length so the saved array is rectangular.
    NUM_POSE_LANDMARKS = 33

    def __init__(self, camera, sector='S1', sequence='Seq1', show_images=False,
                 model_path='../models/mediapipe/pose_landmarker_heavy.task',
                 cache_root='/home/daniela/catkin_ws/src/hpe_real_time/images/'):
        """Create the pose landmarker for *camera*.

        Args:
            camera: Camera folder name under ``imageSequence`` (e.g. ``'camera_0'``).
            sector: Subject folder of the dataset (e.g. ``'S1'``).
            sequence: Sequence folder of the dataset (e.g. ``'Seq1'``).
            show_images: If True, display each frame with its keypoints and
                block until a key is pressed.
            model_path: Path to the MediaPipe ``.task`` model file.
            cache_root: Directory prefix that the output ``.npy`` path is
                built on (default keeps the original hard-coded location).
        """
        self.model_path = model_path
        self.camera = camera
        self.show_images = show_images
        self.cache_root = cache_root
        self.dataset_name = '../images/mpi-inf-3dhp/' + sector + '/' + sequence

        print('Starting keypoint retreival for ' + str(self.camera))

        options = mp.tasks.vision.PoseLandmarkerOptions(
            base_options=mp.tasks.BaseOptions(model_asset_path=self.model_path),
            running_mode=mp.tasks.vision.RunningMode.IMAGE)
        self.landmarker = mp.tasks.vision.PoseLandmarker.create_from_options(options)

        # The trailing '' keeps the trailing separator of the original path string.
        self.dataset_folder = os.path.join(self.dataset_name, 'imageSequence', str(self.camera), '')

    def close(self):
        """Release the landmarker's resources. Calling it more than once is safe."""
        landmarker = getattr(self, 'landmarker', None)
        if landmarker is not None:
            landmarker.close()
            self.landmarker = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    @classmethod
    def _to_pixel_keypoints(cls, keypoints_mp, width, height):
        """Convert normalized landmarks into ``[x_px, y_px, visibility]`` entries.

        A landmark for which MediaPipe's pixel converter returns None (for
        example, one lying outside the image) becomes ``[0, 0, 0]``. When no
        pose was detected (*keypoints_mp* is empty), ``NUM_POSE_LANDMARKS``
        such placeholders are returned, so every frame has the same length.
        """
        if not keypoints_mp:
            return [[0, 0, 0] for _ in range(cls.NUM_POSE_LANDMARKS)]

        keypoints = []
        for pt in keypoints_mp:
            pixel = solutions.drawing_utils._normalized_to_pixel_coordinates(pt.x, pt.y, width, height)
            if pixel is None:
                keypoints.append([0, 0, 0])
            else:
                keypoints.append([pixel[0], pixel[1], pt.visibility])
        return keypoints

    @staticmethod
    def _show_keypoints(image_path, image_name, keypoints):
        """Show the image with numbered keypoints and wait for a key press."""
        frame = cv2.imread(image_path)
        for idx, (x, y, visibility) in enumerate(keypoints):
            if (x, y, visibility) == (0, 0, 0):
                # Placeholder for a missing keypoint; drawing it would pile labels at the origin.
                continue
            cv2.circle(frame, (x, y), 0, color=(0, 0, 255), thickness=3)
            cv2.putText(frame, str(idx), (x, y), cv2.FONT_HERSHEY_PLAIN, 0.5, (0, 255, 0))

        cv2.imshow(image_name, frame)
        cv2.waitKey(0)
        # Windows are named per image; destroy each one so they do not accumulate.
        cv2.destroyWindow(image_name)

    def scan_all_images(self):
        """Detect keypoints on every image of the camera folder and save them.

        Images are processed in natural sort order. Each frame contributes a
        list of ``NUM_POSE_LANDMARKS`` entries ``[x_px, y_px, visibility]``.
        The result is saved with ``np.save`` to
        ``<cache_root>/<dataset_name>/cache/<camera>/2d_pose_mediapipe.npy``.

        Raises:
            RuntimeError: If :meth:`close` has already been called.
        """
        if getattr(self, 'landmarker', None) is None:
            raise RuntimeError('Pose landmarker is closed')

        keypoint_list = []
        save_file_folder = os.path.join(self.cache_root, self.dataset_name, 'cache',
                                        str(self.camera), '2d_pose_mediapipe.npy')

        for image_name in natsorted(os.listdir(self.dataset_folder)):
            image_path = os.path.join(self.dataset_folder, image_name)
            mp_image = mp.Image.create_from_file(image_path)

            pose_landmarker_result = self.landmarker.detect(mp_image)
            annotated_image, keypoints_mp = draw_landmarks_on_image(mp_image.numpy_view(), pose_landmarker_result)

            # shape[:2] also works for single-channel images.
            height, width = annotated_image.shape[:2]
            keypoints = self._to_pixel_keypoints(keypoints_mp, width, height)

            if self.show_images:
                self._show_keypoints(image_path, image_name, keypoints)

            keypoint_list.append(keypoints)

        # np.save does not create missing parent directories.
        os.makedirs(os.path.dirname(save_file_folder), exist_ok=True)
        np.save(save_file_folder, keypoint_list, allow_pickle=True)
        print('Saved file with keypoint values for ' + str(self.camera) + ' with success!')


# def hpe_mediapipe():
#     model_path = '../models/mediapipe/pose_landmarker_heavy.task'
#
#     BaseOptions = mp.tasks.BaseOptions
#     PoseLandmarker = mp.tasks.vision.PoseLandmarker
#     PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
#     VisionRunningMode = mp.tasks.vision.RunningMode
#
#     options = PoseLandmarkerOptions(
#         base_options=BaseOptions(model_asset_path=model_path),
#         running_mode=VisionRunningMode.IMAGE)
#
#     with PoseLandmarker.create_from_options(options) as landmarker:
#         # print()
#
#         # Load the input image from an image file.
#         # mp_image = mp.Image.create_from_file('../images/human36m/processed/S1/Directions-1/imageSequence/54138969/img_000001.jpg')
#         mp_image = mp.Image.create_from_file(
#             '../images/human36m/processed/S1/Directions-1/imageSequence/60457274/img_000159.jpg')
#
#         pose_landmarker_result = landmarker.detect(mp_image)
#         # print(pose_landmarker_result)
#
#         annotated_image = draw_landmarks_on_image(mp_image.numpy_view(), pose_landmarker_result)
#         cv2.imshow('window', cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR))
#         cv2.waitKey(0)
#
#         exit(0)
#
#         # print(pose_landmarker_result)


def main(cameras=None, sector='S1', sequence='Seq1', show_images=False):
    """Extract and cache MediaPipe 2D keypoints for each camera of one sequence.

    Args:
        cameras: Camera folder names to process. Defaults to camera_0-2 and
            camera_4-8. NOTE(review): camera_3 is absent from the default list;
            presumably intentional — confirm.
        sector: Subject folder of the dataset.
        sequence: Sequence folder of the dataset.
        show_images: If True, display every frame with its keypoints.
    """
    if cameras is None:
        cameras = ['camera_0', 'camera_1', 'camera_2', 'camera_4', 'camera_5', 'camera_6', 'camera_7', 'camera_8']

    for camera in cameras:
        cam = hpe_mediapipe(camera, sector=sector, sequence=sequence, show_images=show_images)
        try:
            cam.scan_all_images()
        finally:
            # Free this camera's MediaPipe task before creating the next one.
            cam.landmarker.close()


if __name__ == "__main__":
    main()
