import copy
import os
import json
import argparse
import math
import numpy as np
import cv2
import sys
import argparse
from functools import partial
from colorama import Style, Fore
import datetime
import random

from natsort import natsorted

sys.path.insert(1, '../utils')
from links import links_mpi_inf_3dhp, joints_mpi_inf_3dhp


def update_image(window_name, image_src, keypoints_to_dlt):
    """Return a half-size copy of ``image_src`` with the given keypoints marked.

    Args:
        window_name: Accepted for call-site compatibility; not used here.
        image_src: Image to annotate. It is deep-copied, so the caller's
            image is left untouched.
        keypoints_to_dlt: Iterable of keypoints; only the first two entries
            of each (x, y) are used and truncated to ``int``.

    Returns:
        The annotated copy resized to half of its width and height.
    """
    canvas = copy.deepcopy(image_src)
    marker_color = (0, 0, 255)  # red in OpenCV's BGR channel order

    for kp in keypoints_to_dlt:
        center = (int(kp[0]), int(kp[1]))
        # radius 0 + thickness 10 renders a small filled dot
        cv2.circle(canvas, center, 0, color=marker_color, thickness=10)

    half_width = int(canvas.shape[1] / 2)
    half_height = int(canvas.shape[0] / 2)
    return cv2.resize(canvas, (half_width, half_height))


def save_file(folder, keypoints, output_path, poses_new, pose_idx):
    """Store ``keypoints`` at ``pose_idx`` and write all poses to ``output_path``.

    Nothing is modified or written when ``folder`` does not exist.

    Args:
        folder: Path whose existence gates the update and the save.
        keypoints: Keypoints to place at ``poses_new[pose_idx]``.
        output_path: Destination ``.npy`` path given to ``np.save``.
        poses_new: Indexable collection of poses; updated in place.
        pose_idx: Index of the pose to overwrite.

    Returns:
        ``poses_new`` (the same object that was passed in).
    """
    # Guard clause: without the target folder there is nothing to do.
    if not os.path.exists(folder):
        return poses_new

    poses_new[pose_idx] = keypoints
    np.save(output_path, poses_new, allow_pickle=True)
    print("Saving changes to " + output_path)
    return poses_new


def update_poses(poses, keypoints, pose_idx):
    """Overwrite the entry at ``pose_idx`` in ``poses`` with ``keypoints``.

    The collection is modified in place and also returned, so callers may
    either ignore the result or rebind it.
    """
    poses[pose_idx] = keypoints  # in-place item assignment
    return poses


# Joint names (keys of ``joints_mpi_inf_3dhp``) that are zeroed when a limb is deleted.
_RIGHT_ARM = ('RElbow', 'RWrist', 'RHand')
_LEFT_ARM = ('LElbow', 'LWrist', 'LHand')
_RIGHT_LEG = ('RKnee', 'RAnkle', 'RFoot')
_LEFT_LEG = ('LKnee', 'LAnkle', 'LFoot')

# First key ('r' / 'l') selects a body side; the second key selects what to delete.
# Each entry is (label printed to the console, joints to delete).
_SIDE_MENUS = {
    'r': {
        'a': ("right arm ", _RIGHT_ARM),
        'l': ("right leg ", _RIGHT_LEG),
        's': ("right arm and leg ", _RIGHT_LEG + _RIGHT_ARM),
    },
    'l': {
        'a': ("left arm ", _LEFT_ARM),
        'l': ("left leg ", _LEFT_LEG),
        's': ("left arm and leg ", _LEFT_ARM + _LEFT_LEG),
    },
}

# Value of ``cv2.waitKey(0) & 0xFF`` treated as "right arrow" by this tool.
# NOTE(review): key codes for arrow keys are platform dependent - confirm.
_KEY_NEXT_IMAGE = 83

# Outcomes of one interaction inside a side menu.
_ACTION_STAY = 'stay'
_ACTION_NEXT_IMAGE = 'next_image'
_ACTION_NEXT_CAMERA = 'next_camera'


def _relative_shift(original, shifted):
    """Return ``|original - shifted| / original``, or 0.0 when ``original`` is 0."""
    if original == 0:
        # A zero coordinate has no defined relative error; contribute nothing
        # instead of producing inf/nan (or raising ZeroDivisionError).
        return 0.0
    return abs(original - shifted) / original


def _run_side_menu(menu, window_name, image_gui, keypoints, keypoints_new, keypoints_to_dlt, pose_idx, persist):
    """Wait for the second key of a side menu and apply it.

    Returns a tuple ``(action, image_show)``: ``action`` is one of the
    ``_ACTION_*`` constants and ``image_show`` is the refreshed preview after
    a deletion, or ``None`` when the preview did not change.
    """
    while True:
        key = cv2.waitKey(0) & 0xFF
        choice = menu.get(chr(key))

        if choice is not None:
            label, joint_names = choice
            print("Deleting " + Style.BRIGHT + Fore.BLUE + label + Style.RESET_ALL + "for frame " + str(pose_idx))
            for joint in joint_names:
                joint_idx = joints_mpi_inf_3dhp[joint]['idx']
                # remember the original position so it can be drawn in the preview
                keypoints_to_dlt.append(list(keypoints[joint_idx]))
                keypoints_new[joint_idx] = [0, 0, 0]
            print("Deleting " + str(len(joint_names)) + " keypoints")
            image_show = update_image(window_name, image_gui, keypoints_to_dlt)
            cv2.imshow(window_name, image_show)
            return _ACTION_STAY, image_show

        if key == _KEY_NEXT_IMAGE:  # move to next image.
            cv2.destroyWindow(window_name)
            return _ACTION_NEXT_IMAGE, None
        if key == ord('c'):  # move to next camera.
            print("Moving to next camera")
            cv2.destroyWindow(window_name)
            return _ACTION_NEXT_CAMERA, None
        if key == ord('q'):  # quit.
            print('Quitting ...')
            sys.exit(0)

        # Any other key: write the current state and keep waiting.
        persist()


def _delete_keypoints_manually(camera, camera_folder, output_folder, poses_2d_original, poses_new, max_frames):
    """Interactively delete limbs frame by frame for one camera.

    Keys: ``r`` / ``l`` open the right / left side menu (then ``a`` arm,
    ``l`` leg, ``s`` both), right arrow moves to the next image, ``c`` moves
    to the next camera and ``q`` exits the program.
    """
    output_path = output_folder + '2d_pose_generated_occlusion.npy'
    pose_idx = 0

    for image_name in natsorted(os.listdir(camera_folder)):
        # ``>=`` so that at most ``max_frames`` frames are visited (0/None = no limit).
        if max_frames and pose_idx >= max_frames:
            break

        image_gui = cv2.imread(os.path.join(camera_folder, image_name))
        keypoints = poses_2d_original[pose_idx]
        # Work on a copy so the loaded ground truth is never modified.
        keypoints_new = copy.deepcopy(keypoints)

        # draw initial keypoints in green
        for keypoint in keypoints:
            x = int(keypoint[0])
            y = int(keypoint[1])
            cv2.circle(image_gui, (x, y), 0, color=(0, 255, 0), thickness=10)

        # resize image for visualization
        image_show = cv2.resize(image_gui, (int(image_gui.shape[1] / 2), int(image_gui.shape[0] / 2)))
        keypoints_to_dlt = []
        window_name = 'Frame ' + str(pose_idx) + '_' + camera
        persist = partial(save_file, output_folder, keypoints_new, output_path, poses_new, pose_idx)

        while True:
            cv2.imshow(window_name, image_show)
            key = cv2.waitKey(0) & 0xFF
            action = _ACTION_STAY

            menu = _SIDE_MENUS.get(chr(key))
            if menu is not None:
                action, refreshed = _run_side_menu(menu, window_name, image_gui, keypoints, keypoints_new,
                                                   keypoints_to_dlt, pose_idx, persist)
                if refreshed is not None:
                    image_show = refreshed
            elif key == _KEY_NEXT_IMAGE:  # move to next image.
                cv2.destroyWindow(window_name)
                break
            elif key == ord('c'):  # move to next camera.
                print("Moving to next camera")
                cv2.destroyWindow(window_name)
                return
            elif key == ord('q'):  # quit.
                print('Quitting ...')
                sys.exit(0)

            persist()

            # Navigation requested from inside a side menu is honoured here,
            # after the state has been written.
            if action == _ACTION_NEXT_IMAGE:
                break
            if action == _ACTION_NEXT_CAMERA:
                return

        pose_idx += 1


def _generate_random_errors(args, camera_folder, output_folder, poses_2d_original, poses_new):
    """Add random pixel noise and random occlusions to the poses of one camera."""
    max_frames = args['max_number_of_frames']
    error = args['error']
    # ``-j`` has no default; treat a missing value as "delete no joints".
    max_n_joints = args['max_number_joints_to_delete'] or 0
    # fraction of keypoints that receive noise
    noise_threshold = args['percentage_joints_error'] / 100

    output_path = output_folder + '2d_pose_random_f' + str(args['max_number_of_frames']) + '_e' + str(
        args['error']) + '_j' + str(args['max_number_joints_to_delete']) + '_pj' + str(
        args['percentage_joints_error']) + '.npy'

    np.set_printoptions(suppress=True)

    pose_idx = 0
    keypoints_new = None
    for image_name in natsorted(os.listdir(camera_folder)):
        # ``>=`` so that at most ``max_frames`` frames are modified (0/None = no limit).
        if max_frames and pose_idx >= max_frames:
            break

        keypoints = poses_2d_original[pose_idx]
        # Work on a copy so the loaded ground truth is never modified.
        keypoints_new = copy.deepcopy(keypoints)

        if max_n_joints > len(keypoints):
            raise Exception(
                "The number of joints to occlude can't be higher the the number of joints in the skeleton!")

        # generate x and y errors
        for idx, pt in enumerate(keypoints):
            # random draw decides whether this keypoint gets noise
            if random.uniform(0, 1) < noise_threshold:
                v = np.random.uniform(-1.0, 1.0, 2)
                v = v / np.linalg.norm(v)  # unit direction, scaled by ``error`` below
                x, y = pt[:2] + v * error
                x = int(x)
                y = int(y)

                # confidence drops by the relative displacement on each axis
                conf_decrease = _relative_shift(pt[0], x) + _relative_shift(pt[1], y)
                keypoints_new[idx] = [x, y, pt[2] - conf_decrease]

        # generate random occlusions
        if max_n_joints != 0:
            for index in random.sample(range(len(keypoints)), max_n_joints):
                keypoints_new[index] = [0, 0, 0]

        update_poses(poses_new, keypoints_new, pose_idx)
        pose_idx += 1

    if pose_idx == 0:
        print("No frames processed in " + camera_folder)
        return

    # ``pose_idx`` already points past the last processed frame, so the final
    # write must target ``pose_idx - 1``.
    save_file(output_folder, keypoints_new, output_path, poses_new, pose_idx - 1)


def main():
    """Generate occluded / noisy 2D pose files for every selected camera.

    For each camera the ground-truth poses are loaded from
    ``cache/<camera>/2d_pose_gt.npy``. Without ``-r`` the user deletes limbs
    interactively and the result is written to
    ``2d_pose_generated_occlusion.npy``; with ``-r`` noise and occlusions are
    generated randomly and written to a ``2d_pose_random_*.npy`` file whose
    name encodes the command-line settings.
    """
    ap = argparse.ArgumentParser()
    # NOTE(review): the help texts of -s / -seq mention human36m, but both
    # values are only read on the "mpi" branch below.
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Human 3.6M section to optimize. Only works when the dataset is set to human36m",
                    type=str, default='S1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence optimize. Only works when the dataset is set to human36m",
                    type=str, default='Seq1')
    ap.add_argument("-r", "--random_generation", help="Randomly generates errors and occlusions", action="store_true",
                    default=False)
    ap.add_argument("-e", "--error", help="Error in pixels to add to selected keypoints.",
                    type=int, default=2)
    ap.add_argument("-j", "--max_number_joints_to_delete", help="Maximum number of joints to delete.",
                    type=int)
    ap.add_argument("-max", "--max_number_of_frames",
                    help="Limits the number of frames and creates a smaller subset of the dataset section.",
                    type=int)
    ap.add_argument("-pj", "--percentage_joints_error",
                    help="Percentage of joints that have error.",
                    type=int, default=30)
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to use in optimization", nargs='+',
                    default=['camera_0', 'camera_4', 'camera_5', 'camera_8'])

    ap.add_argument("-si", "--show_images", help="Show images in random generation", action="store_true",
                    default=False)
    args = vars(ap.parse_args())

    dataset_folder = '../../images/' + args['dataset_name'] + '/'
    if args['dataset_name'] == "mpi":
        dataset_folder = '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'

    for camera in args['cameras_to_use']:
        print("Generating occlusions for " + str(camera))
        camera_folder = dataset_folder + 'imageSequence/' + camera + '/'
        poses_2d_path = dataset_folder + 'cache/' + camera + '/2d_pose_gt.npy'
        output_folder = dataset_folder + 'cache/' + camera + '/'
        poses_2d_original = np.load(poses_2d_path, allow_pickle=True)
        # all edits go into this copy; it is what gets written to disk
        poses_new = copy.deepcopy(poses_2d_original)

        if args['random_generation']:
            # Process for randomly generating occlusions and errors
            _generate_random_errors(args, camera_folder, output_folder, poses_2d_original, poses_new)
        else:
            # Process for manually deleting keypoints
            _delete_keypoints_manually(camera, camera_folder, output_folder, poses_2d_original, poses_new,
                                       args['max_number_of_frames'])


# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
