import copy
import os
import json
import argparse
import math
import numpy as np
import cv2
import sys
import argparse
from functools import partial
from colorama import Style, Fore
import datetime
import random

from natsort import natsorted

sys.path.insert(1, '../utils')
from links import links_mpi_inf_3dhp, joints_mpi_inf_3dhp


def update_image(window_name, image_src, keypoints_to_dlt):
    """Return a half-size copy of ``image_src`` with the given keypoints marked in red.

    ``image_src`` itself is not modified: drawing happens on a deep copy.

    Args:
        window_name: accepted for call-site compatibility; not used here.
        image_src: image to annotate (an OpenCV image, i.e. an array with
            shape ``(height, width, ...)``).
        keypoints_to_dlt: iterable of points whose first two entries are the
            x and y pixel coordinates; they are truncated to ``int``.

    Returns:
        The annotated image resized to half its width and height.
    """
    canvas = copy.deepcopy(image_src)
    red = (0, 0, 255)  # OpenCV colours are BGR
    for point in keypoints_to_dlt:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, 0, color=red, thickness=10)

    height, width = canvas.shape[:2]
    return cv2.resize(canvas, (int(width / 2), int(height / 2)))


def save_file(folder, keypoints, output_path, poses_new, pose_idx):
    """Store ``keypoints`` at ``pose_idx`` and write all poses to ``output_path``.

    Nothing happens when ``folder`` does not exist: the poses are returned
    untouched and no file is written.

    Args:
        folder: directory whose existence gates the save.
        keypoints: keypoints for frame ``pose_idx``.
        output_path: destination ``.npy`` file passed to ``np.save``.
        poses_new: container of per-frame poses; updated in place.
        pose_idx: index of the frame to update.

    Returns:
        ``poses_new`` (the same object that was passed in).
    """
    if not os.path.exists(folder):
        return poses_new

    update_poses(poses_new, keypoints, pose_idx)
    np.save(output_path, poses_new, allow_pickle=True)
    print("Saving changes to " + output_path)
    return poses_new


def update_poses(poses, keypoints, pose_idx):
    """Replace the entry at ``pose_idx`` in ``poses`` with ``keypoints``.

    The container is modified in place and also returned, so callers may
    either rely on the mutation or rebind the result.
    """
    container = poses
    container[pose_idx] = keypoints
    return container


# Key code for "next image" as returned by ``cv2.waitKey(0) & 0xFF``; the original
# script notes it as the right arrow (the code can differ between OpenCV backends).
_KEY_NEXT_IMAGE = 83

_RIGHT_ARM = ['RElbow', 'RWrist', 'RHand']
_LEFT_ARM = ['LElbow', 'LWrist', 'LHand']
_RIGHT_LEG = ['RKnee', 'RAnkle', 'RFoot']
_LEFT_LEG = ['LKnee', 'LAnkle', 'LFoot']

# Manual mode: the first key ('r' / 'l') opens a side menu, the second key picks the
# limb group. Each entry maps that second key to (label printed, joint names to delete).
_SIDE_MENUS = {
    ord('r'): {
        ord('a'): ("right arm ", _RIGHT_ARM),
        ord('l'): ("right leg ", _RIGHT_LEG),
        ord('s'): ("right arm and leg ", _RIGHT_LEG + _RIGHT_ARM),
    },
    ord('l'): {
        ord('a'): ("left arm ", _LEFT_ARM),
        ord('l'): ("left leg ", _LEFT_LEG),
        ord('s'): ("left arm and leg ", _LEFT_ARM + _LEFT_LEG),
    },
}

# Possible outcomes of a side menu.
_ACTION_DELETED = 'deleted'
_ACTION_NEXT_IMAGE = 'next_image'
_ACTION_NEXT_CAMERA = 'next_camera'

# Joint indices never picked for random occlusion (hard-coded in the original script;
# which joints they are depends on the skeleton definition in ../utils/links).
_PROTECTED_JOINTS = (3, 8, 13, 22, 27)
# Joint occluded in fixed frame windows when --occlude_fixed_joint is set.
_FIXED_OCCLUDED_JOINT = 10
# (start, end] frame-index windows during which _FIXED_OCCLUDED_JOINT is occluded.
_FIXED_OCCLUSION_WINDOWS = ((10, 15), (25, 30), (40, 45), (55, 60), (70, 75), (85, 90))


def _parse_args():
    """Parse the command-line options and return them as a dict."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-dataset", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/. Please use -dataset human36m to optimize Human 3.6M.",
                    type=str, default='mpi')
    ap.add_argument("-s", "--section",
                    help="Section (subject folder, e.g. S1) to optimize. Used to locate the MPI-INF-3DHP sequence when the dataset is set to mpi",
                    type=str, default='S1')
    ap.add_argument("-seq", "--sequence",
                    help="MPI sequence to optimize. Only works when the dataset is set to mpi",
                    type=str, default='Seq1')
    ap.add_argument("-r", "--random_generation", help="Randomly generates errors and occlusions", action="store_true",
                    default=False)
    ap.add_argument("-e", "--error", help="Error in pixels to add to selected keypoint_list.",
                    type=int, default=2)
    ap.add_argument("-j", "--max_number_joints_to_delete", help="Maximum number of joints to delete.",
                    type=int)
    ap.add_argument("-max", "--max_number_of_frames",
                    help="Limits the number of frames and creates a smaller subset of the dataset section.",
                    type=int)
    ap.add_argument("-sf", "--start_frame",
                    help="Frame to start optimization.",
                    type=int, default=0)
    ap.add_argument("-pj", "--percentage_joints_error",
                    help="Percentage of joints that have error.",
                    type=int, default=30)
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to use in optimization", nargs='+',
                    default=['camera_0', 'camera_2', 'camera_4', 'camera_5', 'camera_7', 'camera_8'])

    ap.add_argument("-si", "--show_images", help="Show images in random generation", action="store_true",
                    default=False)
    ap.add_argument("-ofj", "--occlude_fixed_joint", help="Occlude fixed joint", action="store_true",
                    default=False)
    return vars(ap.parse_args())


def _dataset_folder(args):
    """Return the dataset root folder (with trailing '/') selected by the CLI options."""
    if args['dataset_name'] == "mpi":
        return '../../images/mpi-inf-3dhp/' + args['section'] + '/' + args['sequence'] + '/'
    return '../../images/' + args['dataset_name'] + '/'


def _delete_joints(label, joint_names, keypoints, keypoints_new, keypoints_to_dlt, pose_idx):
    """Zero the named joints in ``keypoints_new`` and record their original positions.

    The original positions (read from ``keypoints``) are appended to
    ``keypoints_to_dlt`` so they can be drawn in red.
    """
    print("Deleting " + Style.BRIGHT + Fore.BLUE + label + Style.RESET_ALL + "for frame " + str(pose_idx))
    for joint in joint_names:
        joint_idx = joints_mpi_inf_3dhp[joint]['idx']
        keypoints_to_dlt.append(list(keypoints[joint_idx]))
        keypoints_new[joint_idx] = [0, 0, 0]
    print("Deleting " + str(len(joint_names)) + " keypoint_list")


def _side_submenu(menu, window_name, image_gui, image_show, keypoints, keypoints_new, keypoints_to_dlt, pose_idx):
    """Wait for a key from one side menu and apply it.

    Returns ``(action, image_show)``. ``action`` is one of the ``_ACTION_*``
    values. ``image_show`` is the image to display next; it is redrawn after a
    deletion. Unrecognised keys are ignored and the menu keeps waiting.
    """
    while True:
        key = cv2.waitKey(0) & 0xFF
        if key in menu:
            label, joint_names = menu[key]
            _delete_joints(label, joint_names, keypoints, keypoints_new, keypoints_to_dlt, pose_idx)
            return _ACTION_DELETED, update_image(window_name, image_gui, keypoints_to_dlt)
        if key == _KEY_NEXT_IMAGE:  # move to next image.
            cv2.destroyWindow(window_name)
            return _ACTION_NEXT_IMAGE, image_show
        if key == ord('c'):  # move to next camera.
            print("Moving to next camera")
            cv2.destroyWindow(window_name)
            return _ACTION_NEXT_CAMERA, image_show
        if key == ord('q'):  # quit; earlier saves are kept on disk.
            print('Quitting ...')
            sys.exit(0)


def _generate_manual_occlusions(camera, camera_folder, output_folder, poses_2d_original, poses_new, max_frames):
    """Interactively delete limb keypoints, frame by frame, for one camera.

    Main-window keys:
        right arrow (83): next image.
        'c': next camera.
        'q': quit.
        'r' / 'l': open the right / left side menu.

    Inside a side menu:
        'a': delete the arm.
        'l': delete the leg.
        's': delete both.
        The navigation keys above also work.

    The result is saved to ``2d_pose_generated_occlusion.npy`` in
    ``output_folder`` after every main-window key other than next-image,
    next-camera and quit.

    Returns the (possibly updated) ``poses_new``.
    """
    output_path = output_folder + '2d_pose_generated_occlusion.npy'
    for pose_idx, image_name in enumerate(natsorted(os.listdir(camera_folder))):
        if max_frames and pose_idx >= max_frames:
            break

        image_gui = cv2.imread(os.path.join(camera_folder, image_name))
        keypoints = poses_2d_original[pose_idx]
        # Edit a copy so deletions don't write through into poses_2d_original.
        keypoints_new = copy.deepcopy(keypoints)

        # draw initial keypoints in green
        for keypoint in keypoints:
            cv2.circle(image_gui, (int(keypoint[0]), int(keypoint[1])), 0, color=(0, 255, 0), thickness=10)

        # resize image for visualization
        image_show = cv2.resize(image_gui, (int(image_gui.shape[1] / 2), int(image_gui.shape[0] / 2)))
        keypoints_to_dlt = []
        window_name = 'Frame ' + str(pose_idx) + '_' + camera

        while True:
            cv2.imshow(window_name, image_show)
            key = cv2.waitKey(0) & 0xFF

            if key == _KEY_NEXT_IMAGE:  # move to next image.
                cv2.destroyWindow(window_name)
                break
            if key == ord('c'):  # move to next camera.
                print("Moving to next camera")
                cv2.destroyWindow(window_name)
                return poses_new
            if key == ord('q'):  # quit; earlier saves are kept on disk.
                print('Quitting ...')
                sys.exit(0)

            action = None
            if key in _SIDE_MENUS:
                action, image_show = _side_submenu(_SIDE_MENUS[key], window_name, image_gui, image_show,
                                                   keypoints, keypoints_new, keypoints_to_dlt, pose_idx)
            poses_new = save_file(output_folder, keypoints_new, output_path, poses_new, pose_idx)

            if action == _ACTION_NEXT_IMAGE:
                break
            if action == _ACTION_NEXT_CAMERA:
                return poses_new
    return poses_new


def _random_output_path(output_folder, args, start_frame, occlude_fixed_joint):
    """Build the output file name that encodes the random-generation parameters."""
    suffix = '_ofj' if occlude_fixed_joint else ''
    return (output_folder + '2d_pose_random_f' + str(args['max_number_of_frames'])
            + '_e' + str(args['error'])
            + '_j' + str(args['max_number_joints_to_delete'])
            + '_pj' + str(args['percentage_joints_error'])
            + '_st_' + str(start_frame) + suffix + '.npy')


def _relative_offset(reference, value):
    """Return ``|reference - value| / reference``, or 0.0 when ``reference`` is 0."""
    if reference == 0:
        return 0.0
    return abs(reference - value) / reference


def _add_keypoint_noise(keypoints, keypoints_new, error, threshold):
    """Displace a random subset of keypoints by ``error`` pixels in a random direction.

    Each keypoint is picked with probability ``threshold``. A picked keypoint
    is moved along a random unit vector and its coordinates are truncated to
    ``int``. Its third value (used as a confidence) is reduced by the relative
    x and y displacement. Results are written into ``keypoints_new``.
    """
    for idx, pt in enumerate(keypoints):
        if random.uniform(0, 1) >= threshold:
            continue
        v = np.random.uniform(-1.0, 1.0, 2)
        v = v / np.linalg.norm(v)
        x, y = pt[:2] + v * error
        x = int(x)
        y = int(y)
        conf_decrease = _relative_offset(pt[0], x) + _relative_offset(pt[1], y)
        keypoints_new[idx] = [x, y, pt[2] - conf_decrease]


def _occlude_keypoints(keypoints_new, pose_idx, num_to_select, occlude_fixed_joint):
    """Zero ``num_to_select`` random joints, plus the fixed joint inside its frame windows.

    Raises:
        ValueError: if more joints are requested than are eligible for occlusion.
    """
    protected = set(_PROTECTED_JOINTS)
    if occlude_fixed_joint:
        protected.add(_FIXED_OCCLUDED_JOINT)

    if num_to_select != 0:
        candidates = [i for i in range(len(keypoints_new)) if i not in protected]
        if num_to_select > len(candidates):
            raise ValueError("Cannot occlude " + str(num_to_select) + " joints per frame: only "
                             + str(len(candidates)) + " joints are eligible for random occlusion.")
        for index in random.sample(candidates, num_to_select):
            keypoints_new[index] = [0, 0, 0]

    if occlude_fixed_joint and any(start < pose_idx <= end for start, end in _FIXED_OCCLUSION_WINDOWS):
        keypoints_new[_FIXED_OCCLUDED_JOINT] = [0, 0, 0]


def _generate_random_occlusions(args, camera_folder, output_folder, poses_2d_original, poses_new):
    """Add random pixel noise and random occlusions to a range of frames and save the result.

    Frames start at ``--start_frame``. The count is bounded by the number of
    files in ``camera_folder``, by the length of the pose array and, if given,
    by ``--max_number_of_frames``.

    Returns the updated ``poses_new``.
    """
    start_frame = args['start_frame']
    max_frames = args['max_number_of_frames']
    error = args['error']
    occlude_fixed_joint = args['occlude_fixed_joint']
    # -j has no default; treat "not given" as "occlude nothing" (None would fail int comparisons).
    max_n_joints = args['max_number_joints_to_delete'] or 0
    threshold = args['percentage_joints_error'] / 100  # fraction of keypoints that get noise
    output_path = _random_output_path(output_folder, args, start_frame, occlude_fixed_joint)

    num_images = len(os.listdir(camera_folder))
    end_frame = min(start_frame + num_images, len(poses_2d_original))
    if max_frames:
        end_frame = min(end_frame, start_frame + max_frames)

    last_idx = None
    last_keypoints = None
    for pose_idx in range(start_frame, end_frame):
        keypoints = poses_2d_original[pose_idx]
        # Edit a copy so the loaded ground truth is never modified.
        keypoints_new = copy.deepcopy(keypoints)

        if max_n_joints > len(keypoints):
            raise ValueError(
                "The number of joints to occlude can't be higher the the number of joints in the skeleton!")

        if error != 0:
            _add_keypoint_noise(keypoints, keypoints_new, error, threshold)
        _occlude_keypoints(keypoints_new, pose_idx, max_n_joints, occlude_fixed_joint)

        poses_new = update_poses(poses_new, keypoints_new, pose_idx)
        last_idx, last_keypoints = pose_idx, keypoints_new

    if last_idx is None:
        print("No frames selected for " + output_path + "; nothing saved")
        return poses_new
    # Re-storing the last processed frame is a no-op; save_file is used for its folder check and message.
    return save_file(output_folder, last_keypoints, output_path, poses_new, last_idx)


def main():
    """Generate occluded / noisy 2D ground-truth poses for each selected camera.

    For every camera in ``--cameras_to_use`` the script loads
    ``cache/<camera>/2d_pose_gt.npy``. It then does one of two things:

    * with ``-r``: randomly perturbs and occludes keypoints
      (``_generate_random_occlusions``);
    * otherwise: lets the user delete limbs interactively in an OpenCV window
      (``_generate_manual_occlusions``).

    ``--start_frame`` is only used by the random mode.
    """
    args = _parse_args()
    dataset_folder = _dataset_folder(args)

    for camera in args['cameras_to_use']:
        print("Generating occlusions for " + str(camera))
        camera_folder = dataset_folder + 'imageSequence/' + camera + '/'
        output_folder = dataset_folder + 'cache/' + camera + '/'
        # allow_pickle lets object arrays load; it also allows code execution, so only load trusted cache files.
        poses_2d_original = np.load(output_folder + '2d_pose_gt.npy', allow_pickle=True)
        poses_new = copy.deepcopy(poses_2d_original)

        if args['random_generation']:
            _generate_random_occlusions(args, camera_folder, output_folder, poses_2d_original, poses_new)
        else:
            _generate_manual_occlusions(camera, camera_folder, output_folder, poses_2d_original, poses_new,
                                        args['max_number_of_frames'])


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
