#!/usr/bin/env python3

import argparse
import math
import os
import json
import sys
from functools import partial
from copy import deepcopy
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
import yaml
from yaml import SafeLoader
from natsort import natsorted

if os.environ.get('USER') == 'mike':
    from OptimizationUtils import Optimizer
    from KeyPressManager import WindowManager
else:
    import OptimizationUtils.OptimizationUtils as OptimizationUtils
    from OptimizationUtils.OptimizationUtils import Optimizer
    from OptimizationUtils.KeyPressManager import WindowManager

from mpl_toolkits.mplot3d import Axes3D

sys.path.append(str(Path("../").resolve()))  # add previous folder to python path

from utils.transforms import getTransform, get_transform_tree_dict
from utils.draw import drawSquare2D, drawCross2D, drawCircle, drawDiagonalCross2D
from utils.links import links_ground_truth, links, joint_idxs, intial_estimate_joints
from utils.projection import projectToCamera
from objective_function import objectiveFunction


def createJSONFile(output_file, D, float_decimals=6):
    """Write ``D`` as pretty-printed JSON (2-space indent, sorted keys) to ``output_file``.

    Args:
        output_file: Path (str or path-like) of the JSON file to create/overwrite.
        D: JSON-serializable object (dicts, lists, tuples, numbers, strings, ...).
        float_decimals: Decimal places every float is rounded to before writing.
            ``None`` keeps full precision. Defaults to 6.
    """
    print("Saving the json output file to " + str(output_file) + ", please wait, it could take a while ...")

    def _round_floats(obj):
        # Recursively round floats. Python 3's json encoder has no float-format hook
        # (assigning json.encoder.FLOAT_REPR is ignored), so the data itself is rounded.
        if isinstance(obj, float):
            return round(obj, float_decimals)
        if isinstance(obj, dict):
            return {key: _round_floats(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_round_floats(value) for value in obj]
        return obj

    data = D if float_decimals is None else _round_floats(D)
    # Write into the file itself (previously the JSON was printed to stdout and the file stayed empty).
    with open(output_file, 'w') as f:
        json.dump(data, f, indent=2, sort_keys=True)
    print("Completed.")


def main():
    """Estimate 3D joint positions from multi-camera 2D detections with an ``Optimizer``.

    Steps, all driven by the command-line arguments parsed below:

    1. Validate the dataset folder (``../../images/<dataset_name>/``), the xacro file
       (``../xacros/<xacro_name>``), the requested camera folders and (unless ``-all``)
       the requested frame images ``<camera>/<idx>.png``.
    2. Per camera, load the image file list, the cached 2D detections
       (``cache/<camera>/2d_pose.npy``), the intrinsics/distortion/size from
       ``<camera>_sim.yaml`` and the extrinsics from the xacro transform tree.
    3. Keep only frame indices present in every camera (the shortest camera bounds them).
    4. Create one X/Y/Z parameter group per (frame, joint) and push projection,
       consecutive-frame and link-length residuals, evaluated by ``objectiveFunction``.
    5. Optionally show OpenCV/matplotlib visualizations (``-si``) and optionally save
       ``poses3d.json`` and ``cameras.json`` into the dataset folder (``-o``).

    Raises:
        FileNotFoundError: if the dataset folder, the xacro file, a requested camera
            folder, a requested frame image or (with ``-gt``) the ground-truth file is missing.
    """
    # --------------------------------------------------------
    # Arguments
    # --------------------------------------------------------
    ap = argparse.ArgumentParser()
    ap.add_argument("-sff", "--skip_frame_to_frame_residuals", help="Skip the frame to frame residuals",
                    action="store_true", default=False)
    ap.add_argument("-sll", "--skip_link_length_residuals", help="Skip the link length residuals",
                    action="store_true", default=False)
    ap.add_argument("-gt", "--has_ground_truth", help="Use groundtruth in visualization", action="store_true",
                    default=False)
    ap.add_argument("-db", "--debug", help="Prints debug lines", action="store_true",
                    default=False)
    ap.add_argument("-si", "--show_images", help="Show optimization images", action="store_true",
                    default=False)
    # Phased execution is on by default (store_true + default=True), so it needs a separate
    # switch to be turned off; both options write the same 'phased_execution' key.
    ap.add_argument("-pe", "--phased_execution",
                    help="Wait for a key press before and after the optimization (only with -si). On by default.",
                    action="store_true",
                    default=True)
    ap.add_argument("-npe", "--no_phased_execution", dest="phased_execution", action="store_false",
                    help="Disable phased execution (do not wait for key presses).")
    ap.add_argument("-o", "--save_output", help="Save output json file with final skeleton and poses",
                    action="store_true",
                    default=False)
    ap.add_argument("-all", "--optimize_all",
                    help=" Use to optimize the entire dataset/video. If this argument is set, the program ignores the -frames flag.",
                    action="store_true",
                    default=False)  # TODO if -all is set to true, don't plot images and save video in the end of the optimization
    ap.add_argument("-dataset_name", "--dataset_name",
                    help="Dataset name. Datasets must be inside hpe/images/.",
                    type=str, default='sim_moving')
    ap.add_argument("-xacro", "--xacro_name", help="Name of xacro file. Must be inside folder xacros/.", type=str,
                    default="novo.urdf.xacro")
    ap.add_argument("-frames", "--frames_to_use", help="Frames to use in optimization. ", nargs='+', type=int,
                    default=[1])
    ap.add_argument("-cams", "--cameras_to_use", help="Cameras to use in optimization", nargs='+',
                    default=["camera_1", "camera_2", "camera_3", "camera_4"])

    args = vars(ap.parse_args())

    # --------------------------------------------------------
    # Configuration and verifications
    # --------------------------------------------------------
    world_frame = 'world'

    dataset_folder = '../../images/' + args['dataset_name'] + '/'
    if not os.path.exists(dataset_folder):
        raise FileNotFoundError("The dataset folder does not exist!")

    file_xacro = "../xacros/" + args['xacro_name']
    if not os.path.exists(file_xacro):
        raise FileNotFoundError("The xacro file does not exist! The xacro file must be inside the hpe/scripts/xacros folder.")
    transform_dict = get_transform_tree_dict(file_xacro)

    # Define cameras to use (each must be a folder inside the dataset)
    cameras = {}
    for camera_key in args["cameras_to_use"]:
        if not os.path.exists(dataset_folder + camera_key):
            raise FileNotFoundError("The given sensor " + camera_key + " does not exist in the dataset!")
        cameras[camera_key] = {}

    # Define frames to use (None means every frame common to all cameras)
    idxs_to_use = None if args['optimize_all'] else args['frames_to_use']

    # Only joints that take part in at least one link (as parent or child) are used
    joints_to_use = {joint_name for link in links.values() for joint_name in (link['parent'], link['child'])}

    # Paths to files
    for camera_key, camera in cameras.items():
        camera['folder'] = dataset_folder + camera_key + '/'
        camera['folder_skeleton'] = dataset_folder + 'cache/' + camera_key + '/'
        camera['files'] = natsorted(os.listdir(camera['folder']))
        camera['file_skeleton'] = camera['folder_skeleton'] + '2d_pose.npy'
        camera['frame_id'] = camera_key + '_rgb_optical_frame'  # TODO change to frame_id

        # Check that every requested frame exists
        if idxs_to_use is not None:
            for idx in idxs_to_use:
                if not os.path.exists(camera['folder'] + str(idx) + '.png'):
                    raise FileNotFoundError("The frame " + str(idx) + " does not exist in the dataset!")

    # Read camera intrinsics
    for camera_key, camera in cameras.items():
        yaml_file = dataset_folder + camera_key + '_sim.yaml'
        with open(yaml_file) as f:
            data = yaml.load(f, Loader=SafeLoader)
        camera['intrinsics'] = np.reshape(data['camera_matrix']['data'], (3, 3))
        camera['distortion'] = np.reshape(data['distortion_coefficients']['data'], (5, 1))
        camera['width'] = data['image_width']
        camera['height'] = data['image_height']

    # Read camera extrinsics
    for camera in cameras.values():
        camera['extrinsics'] = getTransform(camera['frame_id'], world_frame, transform_dict)

    # Load the cached 2D detections once per camera. allow_pickle=True lets np.load unpickle
    # object arrays, which can execute code: only load cache files produced by this pipeline.
    poses_per_camera = {camera_key: np.load(camera['file_skeleton'], allow_pickle=True).tolist()
                        for camera_key, camera in cameras.items()}

    # Number of frames usable by every camera: each camera pairs files with detections
    # (zip semantics), and the shortest camera bounds the total.
    frames_total = min(min(len(camera['files']), len(poses_per_camera[camera_key]))
                       for camera_key, camera in cameras.items())

    # Read images and files
    idxs_to_use_set = None if idxs_to_use is None else set(idxs_to_use)
    for camera_key, camera in cameras.items():
        camera['frames'] = {}
        for idx, (file, pose_np) in enumerate(zip(camera['files'], poses_per_camera[camera_key])):
            # Valid indices are 0 .. frames_total - 1 (previously '>' let idx == frames_total through,
            # a frame that only the longer cameras have).
            if idx >= frames_total:
                break
            if idxs_to_use_set is not None and idx not in idxs_to_use_set:
                continue

            frame = cv2.imread(camera['folder'] + file)

            joints = {}
            for joint_key, joint_data in joint_idxs.items():
                if joint_key not in joints_to_use:
                    continue

                point = pose_np[joint_data['idx']]
                x, y = point[0], point[1]
                # A detection at exactly x == 0 or y == 0 is treated as missing
                joint = {'x': x, 'y': y}
                if len(point) == 3:
                    joint['confidence'] = point[2]
                joint.update({'valid': (x != 0) and (y != 0),
                              'x_proj': 0.0, 'y_proj': 0.0, 'color': joint_data['color']})
                joints[joint_key] = joint
                # TODO implement a better initialization

            camera['frames'][str(idx)] = {'image_path': camera['folder'] + file, 'image': frame, 'joints': joints}

    if args['debug']:
        print(frames_total)

    # read ground truth
    ground_truth = None
    if args['has_ground_truth']:
        ground_truth_path = dataset_folder + "ground_truth.txt"
        if not os.path.exists(ground_truth_path):
            raise FileNotFoundError("The ground truth file does not exist!")
        with open(ground_truth_path, "r") as fp:
            ground_truth = json.load(fp)

    # Create the 3D poses dictionary. It mirrors the frames of the first camera; every camera
    # was limited to the same frames above, so any camera would do.
    selected_camera_key = next(iter(cameras))
    selected_camera = cameras[selected_camera_key]
    poses3d = {}
    for frame_key, frame in selected_camera['frames'].items():
        # TODO consider initializing from intial_estimate_joints instead of the origin
        poses3d[frame_key] = {joint_key: {'X': 0.0, 'Y': 0.0, 'Z': 0.0, 'color': joint['color']}
                              for joint_key, joint in frame['joints'].items()}

    # -----------------------------------------
    # OPTIMIZATION
    # -----------------------------------------
    # Configure optimizer
    opt = Optimizer()
    opt.addDataModel('args', args)
    opt.addDataModel('cameras', cameras)
    opt.addDataModel('poses3d', poses3d)

    # ----------------------------------------------
    # Setters and getters
    # ----------------------------------------------
    def getJointXYZ(data, frame_key, joint_key):
        """Return [X, Y, Z] of ``joint_key`` in ``frame_key`` from the poses3d data model."""
        d = data[frame_key][joint_key]
        return [d['X'], d['Y'], d['Z']]

    def setJointXYZ(data, values, frame_key, joint_key):
        """Write ``values`` (X, Y, Z) into ``joint_key`` of ``frame_key`` in the poses3d data model."""
        d = data[frame_key][joint_key]
        d['X'], d['Y'], d['Z'] = values[0], values[1], values[2]

    # ----------------------------------------------
    # Create the parameters: one X/Y/Z group per (frame, joint)
    # ----------------------------------------------
    for frame_key, frame in selected_camera['frames'].items():
        for joint_key in frame['joints']:
            group_name = 'frame_' + frame_key + '_joint_' + joint_key
            opt.pushParamVector(group_name, data_key='poses3d',
                                getter=partial(getJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                setter=partial(setJointXYZ, frame_key=frame_key, joint_key=joint_key),
                                suffix=['_X', '_Y', '_Z'])

    # ----------------------------------------------
    # Define the objective function
    # ----------------------------------------------
    opt.setObjectiveFunction(objectiveFunction)

    # ----------------------------------------------
    # Define the residuals
    # ----------------------------------------------
    # Projection residuals: one per (camera, frame, valid joint)
    for camera_key, camera in cameras.items():
        for frame_key, frame in camera['frames'].items():
            for joint_key, joint in frame['joints'].items():
                if not joint['valid']:  # skip invalid joints
                    continue

                parameter_pattern = 'frame_' + frame_key + '_joint_' + joint_key
                residual_key = 'projection_sensor_' + camera_key + '_frame_' + frame_key + '_joint_' + joint_key

                params = opt.getParamsContainingPattern(pattern=parameter_pattern)
                opt.pushResidual(name=residual_key, params=params)

    # Frame distance residuals: link each joint across consecutive frame indices
    if not args['skip_frame_to_frame_residuals']:
        frame_keys = list(poses3d.keys())
        for initial_frame_key, final_frame_key in zip(frame_keys[:-1], frame_keys[1:]):
            if int(final_frame_key) - int(initial_frame_key) != 1:  # frames are not consecutive
                continue

            for joint_key in poses3d[initial_frame_key]:
                residual_key = 'consecutive_frame_' + initial_frame_key + '_frame_' + final_frame_key + '_joint_' + joint_key

                # The trailing '_' keeps e.g. joint 'Neck' from also matching 'Neck2'-style names
                params = opt.getParamsContainingPattern(
                    pattern='frame_' + initial_frame_key + '_joint_' + joint_key + '_')
                params.extend(opt.getParamsContainingPattern(
                    pattern='frame_' + final_frame_key + '_joint_' + joint_key + '_'))
                opt.pushResidual(name=residual_key, params=params)

    # Link length residuals: one per (link, frame)
    if not args['skip_link_length_residuals']:
        for link in links.values():
            print('Link ' + link['parent'] + '-' + link['child'])

            for frame_key in poses3d:
                params = opt.getParamsContainingPattern(pattern='frame_' + frame_key + '_joint_' + link['parent'] + '_')
                params.extend(
                    opt.getParamsContainingPattern(pattern='frame_' + frame_key + '_joint_' + link['child'] + '_'))

                residual_key = 'length_frame_' + frame_key + '_joint_' + link['parent'] + '_joint_' + link['child']
                opt.pushResidual(name=residual_key, params=params)

    opt.printResiduals()

    # ----------------------------------------------
    # Compute sparse matrix
    # ----------------------------------------------
    opt.computeSparseMatrix()
    opt.printSparseMatrix()

    # ----------------------------------------------
    # Visualization setup
    # ----------------------------------------------
    if args['show_images']:
        # Annotated base images: detected links (grey lines) ...
        for camera in cameras.values():
            for frame in camera['frames'].values():
                frame['image_gui'] = deepcopy(frame['image'])
                for link in links.values():
                    joint0 = frame['joints'][link['parent']]
                    joint1 = frame['joints'][link['child']]

                    if not joint0['valid'] or not joint1['valid']:
                        continue

                    x0, y0 = int(joint0['x']), int(joint0['y'])
                    x1, y1 = int(joint1['x']), int(joint1['y'])

                    cv2.line(frame['image_gui'], (x0, y0), (x1, y1), (128, 128, 128), 3)

        # ... detected joints (squares sized by detection confidence) ...
        for camera in cameras.values():
            for frame in camera['frames'].values():
                for joint in frame['joints'].values():
                    if not joint['valid']:
                        continue

                    x, y = int(joint['x']), int(joint['y'])
                    # Detections without a confidence value are drawn at full size
                    square_size = 5 + (20 - 5) * joint.get('confidence', 1.0)
                    drawSquare2D(frame['image_gui'], x, y, square_size, color=joint['color'], thickness=3)

        # ... and projected ground-truth joints (black diagonal crosses)
        if args['has_ground_truth']:
            for camera in cameras.values():
                for frame_key, frame in camera['frames'].items():
                    for gt_joint in ground_truth[frame_key].values():
                        pose = gt_joint['pose']
                        # Homogeneous world point (4x1)
                        pts_in_world = np.array([[float(pose['x'])], [float(pose['y'])], [float(pose['z'])], [1.0]],
                                                dtype=float)

                        pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)

                        # Project 3D point to camera
                        pts_in_image, _, _ = projectToCamera(np.array(camera['intrinsics']),
                                                             np.array(camera['distortion']),
                                                             camera['width'],
                                                             camera['height'], pts_in_sensor[0:3, :])

                        x, y = int(pts_in_image[0][0]), int(pts_in_image[1][0])
                        drawDiagonalCross2D(frame['image_gui'], x, y, 10, color=(0, 0, 0), thickness=3)

        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                window_name = camera_key + '_' + frame_key
                cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
                cv2.imshow(window_name, frame['image_gui'])

        # One 3D figure per frame
        plot_handles = {}
        for frame_key in selected_camera['frames']:
            figure = plt.figure()
            figure.suptitle('Frame #' + str(frame_key), fontsize=14)
            axes = figure.add_subplot(111, projection='3d')
            axes.set_xlim3d(-1, 1)
            axes.set_ylim3d(-3, -1)
            axes.set_zlim3d(0, 2)
            axes.set_xlabel('x')
            axes.set_ylabel('y')
            axes.set_zlabel('z')
            plot_handles[frame_key] = {'figure': figure, 'axes': axes}

        wm = WindowManager(figs=[plot_handle['figure'] for plot_handle in plot_handles.values()])
        opt.addDataModel('window_manager', wm)

        # Joint scatter: estimated joints in their own colors, ground-truth joints in black
        for frame_key, frame in selected_camera['frames'].items():
            X_vec, Y_vec, Z_vec, joint_colors = [], [], [], []
            for joint_key in frame['joints']:
                point = poses3d[frame_key][joint_key]
                X_vec.append(point['X'])
                Y_vec.append(point['Y'])
                Z_vec.append(point['Z'])
                # Colors are stored BGR (also used for the OpenCV drawing); matplotlib wants RGB in [0, 1]
                b, g, r = tuple(c / 255 for c in point['color'])
                joint_colors.append((r, g, b))

            if args['has_ground_truth']:
                for gt_joint in ground_truth[frame_key].values():
                    point = gt_joint['pose']
                    X_vec.append(float(point['x']))
                    Y_vec.append(float(point['y']))
                    Z_vec.append(float(point['z']))
                    joint_colors.append((0, 0, 0))

            plot_handles[frame_key]['joint_handle'] = plot_handles[frame_key]['axes'].scatter(X_vec, Y_vec, Z_vec,
                                                                                              c=joint_colors)

        # Link lines: ground truth in black, estimates in grey
        for frame_key in selected_camera['frames']:
            axes = plot_handles[frame_key]['axes']
            if args['has_ground_truth']:
                plot_handles[frame_key]['ground_truth'] = {}
                for link_name, link in links_ground_truth.items():
                    joint0 = ground_truth[frame_key][link['parent']]['pose']
                    joint1 = ground_truth[frame_key][link['child']]['pose']
                    plot_handles[frame_key]['ground_truth'][link_name] = {
                        'link_handle_gt': axes.plot([joint0['x'], joint1['x']],
                                                    [joint0['y'], joint1['y']],
                                                    [joint0['z'], joint1['z']],
                                                    c=(0, 0, 0))}

            for link_name, link in links.items():
                joint0 = poses3d[frame_key][link['parent']]
                joint1 = poses3d[frame_key][link['child']]
                plot_handles[frame_key][link_name] = {
                    'link_handle': axes.plot([joint0['X'], joint1['X']],
                                             [joint0['Y'], joint1['Y']],
                                             [joint0['Z'], joint1['Z']],
                                             c=(0.5, 0.5, 0.5))}
        plt.draw()
        plt.pause(0.0001)

    def visualization_function(data):
        """Redraw 2D reprojections and 3D skeletons from the optimizer's current data models.

        Reads the 'cameras', 'poses3d' and 'window_manager' data models from ``data``; only
        registered when ``-si`` is given, so ``plot_handles`` exists whenever it runs.
        """
        cameras = data['cameras']
        poses3d = data['poses3d']
        wm = data['window_manager']

        for camera_key, camera in cameras.items():
            for frame_key, frame in camera['frames'].items():
                image = deepcopy(frame['image_gui'])
                for joint in frame['joints'].values():
                    drawCross2D(image, joint['x_proj'], joint['y_proj'], 15, color=joint['color'], thickness=3)
                # Show once per frame, after all crosses are drawn (was re-shown after every joint)
                cv2.imshow(camera_key + '_' + frame_key, image)

        for frame_key, frame in selected_camera['frames'].items():
            X_vec, Y_vec, Z_vec = [], [], []
            for joint_key in frame['joints']:
                point = poses3d[frame_key][joint_key]
                X_vec.append(point['X'])
                Y_vec.append(point['Y'])
                Z_vec.append(point['Z'])

            if args['has_ground_truth']:
                for gt_joint in ground_truth[frame_key].values():
                    point = gt_joint['pose']
                    X_vec.append(float(point['x']))
                    Y_vec.append(float(point['y']))
                    Z_vec.append(float(point['z']))

            plot_handles[frame_key]['joint_handle']._offsets3d = (X_vec, Y_vec, Z_vec)
            plt.draw()

            for link_name, link in links.items():
                joint0 = poses3d[frame_key][link['parent']]
                joint1 = poses3d[frame_key][link['child']]
                line = plot_handles[frame_key][link_name]['link_handle'][0]
                line.set_xdata([joint0['X'], joint1['X']])
                line.set_ydata([joint0['Y'], joint1['Y']])
                line.set_3d_properties(zs=[joint0['Z'], joint1['Z']])

            if args['has_ground_truth']:
                for link_name, link in links_ground_truth.items():
                    joint0 = ground_truth[frame_key][link['parent']]['pose']
                    joint1 = ground_truth[frame_key][link['child']]['pose']
                    line = plot_handles[frame_key]['ground_truth'][link_name]['link_handle_gt'][0]
                    line.set_xdata([joint0['x'], joint1['x']])
                    line.set_ydata([joint0['y'], joint1['y']])
                    line.set_3d_properties(zs=[joint0['z'], joint1['z']])

            plt.draw()
        wm.waitForKey(time_to_wait=0.001, verbose=False, message='')

    if args['show_images']:
        opt.setVisualizationFunction(visualization_function, always_visualize=True, niterations=1)
        opt.internal_visualization = False

    # ----------------------------------------------
    # Start the optimization
    # ----------------------------------------------
    if args['phased_execution'] and args['show_images']:
        wm.waitForKey(verbose=True, message='Ready to start optimization. Press \'c\' to continue.')

    opt.startOptimization(optimization_options={'x_scale': 'jac', 'ftol': 1e-7,
                                                'xtol': 1e-7, 'gtol': 1e-7, 'diff_step': None})

    if args['save_output']:
        print("Saving output files...")
        for camera in cameras.values():
            camera['intrinsics'] = camera['intrinsics'].tolist()
            camera['extrinsics'] = camera['extrinsics'].tolist()
            camera['distortion'] = camera['distortion'].tolist()
            for frame in camera['frames'].values():
                # Images are not serialized. pop() tolerates 'image_gui' being absent, which is
                # the case when -si was not given (del raised KeyError there).
                frame.pop('image', None)
                frame.pop('image_gui', None)
        with open(dataset_folder + "poses3d.json", "w") as fp:
            json.dump(poses3d, fp)
        with open(dataset_folder + "cameras.json", "w") as fp:
            json.dump(cameras, fp)

    if args['phased_execution'] and args['show_images']:
        wm.waitForKey(verbose=True, message='Optimization finished. Press \'q\' to quit.')


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported as a module.
    main()
