import math

import numpy as np
from utils.projection import projectToCamera
from utils.links import links_ground_truth, links_openpose, joint_idxs, joints_mediapipe, links_mediapipe, \
    joint_idxs_op_mpi, links_op_mpi, links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from pytictoc import TicToc
from numba import jit, cuda

from prettytable import PrettyTable
from colorama import Style, Fore


def pointsDistance(xi, yi, zi, xf, yf, zf):
    """Return the Euclidean distance between (xi, yi, zi) and (xf, yf, zf)."""
    dx = xf - xi
    dy = yf - yi
    dz = zf - zi
    squared_norm = dx ** 2 + dy ** 2 + dz ** 2
    return math.sqrt(squared_norm)


def average(lst):
    """Return the arithmetic mean of ``lst`` as a float.

    Raises ZeroDivisionError when ``lst`` is empty.
    """
    numerator, denominator = float(sum(lst)), float(len(lst))
    return numerator / denominator


def _select_links(detector_name):
    """Return the skeleton link table for the configured 2D detector.

    Raises:
        ValueError: if ``detector_name`` is not a known detector.
    """
    if detector_name == "mediapipe":
        return links_mediapipe
    if detector_name == "openpose":
        return links_openpose
    if detector_name == "mpi":
        return links_mpi_inf_3dhp
    raise ValueError(f"Unknown 2D detector: {detector_name}")


def _joint_index_table(detector_name):
    """Return the joint-name -> {'idx': ...} table used to index ``frame['joints']``."""
    # NOTE(review): every detector other than "mpi" (including "openpose") is
    # indexed with the MediaPipe table, exactly as before; joint_idxs /
    # joint_idxs_op_mpi were selected but never used. Confirm this is intended.
    if detector_name == "mpi":
        return joints_mpi_inf_3dhp
    return joints_mediapipe


def _add_projection_residuals(cameras, poses3d, frames_to_use, joints_to_use, joint_table, residuals):
    """Add one reprojection residual per (camera, frame, valid joint) to ``residuals``.

    Side effect: every visited 2D joint gets its projected pixel coordinates
    stored as ``x_proj`` / ``y_proj``, whether or not it is valid.
    """
    for camera_key, camera in cameras.items():
        for frame_key in frames_to_use:
            frame = camera['frames'][str(frame_key)]  # camera frames are keyed by str
            pose = poses3d[int(frame_key)]  # 3D poses are keyed by int
            for joint_key in joints_to_use:
                joint = frame['joints'][joint_table[joint_key]['idx']]
                point3d = pose[joint_key]

                # Homogeneous 4x1 world coordinates of the current 3D estimate.
                pts_in_world = np.ones((4, 1), dtype=float)
                pts_in_world[0][0] = point3d['X']
                pts_in_world[1][0] = point3d['Y']
                pts_in_world[2][0] = point3d['Z']

                # World -> camera frame, then camera frame -> image pixels.
                pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)
                pts_in_image, _, _ = projectToCamera(camera['intrinsics'], camera['distortion'], camera['width'],
                                                     camera['height'], pts_in_sensor[0:3, :])

                x_proj = pts_in_image[0][0]
                y_proj = pts_in_image[1][0]
                joint['x_proj'] = x_proj
                joint['y_proj'] = y_proj

                if joint['valid']:
                    residual_key = ('projection_sensor_' + camera_key + '_frame_' + str(frame_key)
                                    + '_joint_' + joint_key)
                    # Pixel distance between detection and projection, weighted by detector confidence.
                    residuals[residual_key] = joint['confidence'] * 2 * math.sqrt(
                        (joint['x'] - x_proj) ** 2 + (joint['y'] - y_proj) ** 2)


def _add_consecutive_frame_residuals(cameras, poses3d, joint_table, residuals, max_jump=150):
    """Penalise joint motion between consecutive frames.

    For every pair of adjacent ``poses3d`` entries whose keys differ by exactly 1,
    each joint's residual is its 3D displacement. The displacement is kept only
    when it exceeds ``max_jump`` or the joint is invalid in either frame for some
    camera; otherwise the residual is 0.
    """
    frame_keys = list(poses3d.keys())
    for initial_frame_key, final_frame_key in zip(frame_keys[:-1], frame_keys[1:]):
        if int(final_frame_key) - int(initial_frame_key) != 1:  # frames are not consecutive
            continue

        initial_pose = poses3d[initial_frame_key]
        final_pose = poses3d[final_frame_key]
        # Camera frames are keyed by str (see the projection residuals).
        initial_camera_key = str(initial_frame_key)
        final_camera_key = str(final_frame_key)

        for joint_key, initial_point in initial_pose.items():
            final_point = final_pose[joint_key]
            displacement = pointsDistance(initial_point['X'], initial_point['Y'], initial_point['Z'],
                                          final_point['X'], final_point['Y'], final_point['Z'])

            # 2D joints are indexed by the joint's 'idx', not by its name.
            joint_idx = joint_table[joint_key]['idx']
            invalid_joint = any(
                not camera['frames'][initial_camera_key]['joints'][joint_idx]['valid']
                or not camera['frames'][final_camera_key]['joints'][joint_idx]['valid']
                for camera in cameras.values())

            residual_key = ('consecutive_frame_' + str(initial_frame_key) + '_frame_' + str(final_frame_key)
                            + '_joint_' + joint_key)
            residuals[residual_key] = displacement if displacement > max_jump or invalid_joint else 0


def _add_link_length_residuals(links, poses3d, groundtruth, iteration, debug, residuals):
    """Add one residual per skeleton link: the mean absolute deviation of its length across frames.

    Args:
        groundtruth: per-frame ground truth (``groundtruth[frame][joint]['pose']``
            with 'x'/'y'/'z'), or None. It is used only for the debug table.
        debug: when true, build a PrettyTable summarising each link.

    Returns:
        The debug PrettyTable, or None when ``debug`` is false.
    """
    table = None
    if debug:
        table = PrettyTable(['Iteration #', 'Frame', 'Parent', 'Child', 'Groundtruth', 'Measured', 'Error',
                             'Std Deviation'])

    for link in links.values():
        parent, child = link['parent'], link['child']

        lengths = []
        for pose in poses3d.values():
            parent_pose, child_pose = pose[parent], pose[child]
            lengths.append(pointsDistance(parent_pose['X'], parent_pose['Y'], parent_pose['Z'],
                                          child_pose['X'], child_pose['Y'], child_pose['Z']))
        if not lengths:  # no frames: nothing to compare against
            continue

        mean_length = average(lengths)
        # Mean absolute deviation from the average length (labelled "Std Deviation" in the table).
        deviation_sum = 0
        for length in lengths:
            deviation_sum += abs(length - mean_length)
        link_length_residual = deviation_sum / len(lengths)

        if table is not None:
            gt_mean = None
            if groundtruth is not None:
                gt_lengths = []
                for frame_key in poses3d:
                    parent_gt = groundtruth[frame_key][parent]['pose']
                    child_gt = groundtruth[frame_key][child]['pose']
                    gt_lengths.append(pointsDistance(parent_gt['x'], parent_gt['y'], parent_gt['z'],
                                                     child_gt['x'], child_gt['y'], child_gt['z']))
                gt_mean = average(gt_lengths)
            table.add_row([
                str(iteration),
                Fore.BLUE + Style.BRIGHT + 'AVERAGE ' + Style.RESET_ALL,
                parent,
                child,
                '%.1f' % gt_mean if gt_mean is not None else 'n/a',
                '%.1f' % mean_length,
                '%.1f' % abs(gt_mean - mean_length) if gt_mean is not None else 'n/a',
                '%.1f' % abs(link_length_residual),
            ])

        residuals['link_length_joint_' + parent + '_joint_' + child] = link_length_residual

    return table


def _record_joint_errors(poses3d, groundtruth, joints_to_use, errors, iteration):
    """Store per-frame, per-joint 3D errors against ground truth in ``errors[iteration]`` and print them."""
    table = PrettyTable(['Frame #'] + list(joints_to_use))
    errors[iteration] = {}
    for frame_key, pose in poses3d.items():
        errors[iteration][frame_key] = {}
        row = [frame_key]
        for joint_key in joints_to_use:
            estimate = pose[joint_key]
            joint_gt = groundtruth[frame_key][joint_key]['pose']
            # Euclidean distance; stored under 'rmse' for compatibility with existing consumers.
            error = math.dist([estimate['X'], estimate['Y'], estimate['Z']],
                              [joint_gt['x'], joint_gt['y'], joint_gt['z']])
            errors[iteration][frame_key][joint_key] = {'rmse': error}
            row.append('%.1f' % error)
        table.add_row(row)
    print(table)


def objectiveFunction(data):
    """Compute the residual dictionary minimised by the pose optimiser.

    ``data`` must provide 'args', 'cameras', 'poses3d', 'joints_to_use',
    'frames_to_use' and 'status'. It may also provide 'groundtruth' (read only
    when ``args['has_ground_truth']``) and 'errors' (written in debug iterations
    that have ground truth).

    Residual groups:
      * projection residuals (always);
      * consecutive-frame residuals (unless ``args['skip_frame_to_frame_residuals']``);
      * link-length residuals (unless ``args['skip_link_length_residuals']``).

    Returns:
        dict mapping residual name -> scalar residual.

    Raises:
        ValueError: if ``args['2d_detector']`` is not a known detector.
    """
    args = data['args']
    cameras = data['cameras']
    poses3d = data['poses3d']
    frames_to_use = data['frames_to_use']
    # Remove duplicates while keeping a deterministic, order-preserving joint list.
    joints_to_use = list(dict.fromkeys(data['joints_to_use']))
    groundtruth = data['groundtruth'] if args['has_ground_truth'] else None

    links = _select_links(args['2d_detector'])
    joint_table = _joint_index_table(args['2d_detector'])

    residuals = {}
    _add_projection_residuals(cameras, poses3d, frames_to_use, joints_to_use, joint_table, residuals)

    if not args['skip_frame_to_frame_residuals']:
        _add_consecutive_frame_residuals(cameras, poses3d, joint_table, residuals)

    is_debug_iteration = args['debug'] and data['status']['is_iteration']

    if not args['skip_link_length_residuals']:
        link_table = _add_link_length_residuals(links, poses3d, groundtruth, data['status']['num_iterations'],
                                                args['debug'], residuals)
        if is_debug_iteration:
            print(link_table)

    # Per-joint error reporting needs ground truth; skip it when none is available.
    if is_debug_iteration and groundtruth is not None:
        _record_joint_errors(poses3d, groundtruth, joints_to_use, data['errors'],
                             data['status']['num_iterations'])

    return residuals
