import math
import numpy as np
from numba import njit
from utils.projection import projectToCamera_faster, batch_projectToCamera
from utils.links import links_ground_truth, links_openpose, joint_idxs, joints_mediapipe, links_mediapipe, \
    joint_idxs_op_mpi, links_op_mpi, links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from pytictoc import TicToc
from prettytable import PrettyTable
from colorama import Style, Fore


@njit
def pointsDistance(xi, yi, zi, xf, yf, zf):
    """Return the Euclidean distance between (xi, yi, zi) and (xf, yf, zf).

    Compiled with numba's ``njit``, so it is intended to be called with
    plain scalar arguments.
    """
    dx = xf - xi
    dy = yf - yi
    dz = zf - zi
    return math.sqrt(dx ** 2 + dy ** 2 + dz ** 2)


@njit
def update_joints_and_calculate_residuals(pts_in_image_frame, joints_data, joints_valid, joints_confidence):
    """Compute confidence-weighted 2D reprojection errors for one frame.

    Args:
        pts_in_image_frame: projected points, indexed as ``[axis, 0, joint]``
            with axis 0 holding x and axis 1 holding y.
        joints_data: per-joint detected ``(x, y)`` pairs.
        joints_valid: per-joint truthy flag; only valid joints get a residual.
        joints_confidence: per-joint detection confidence used as a weight.

    Returns:
        Array with one entry per joint: ``2 * confidence * pixel distance``
        for valid joints and 0 for invalid ones.
    """
    n_joints = len(joints_data)
    out = np.zeros(n_joints)
    for j in range(n_joints):
        u = pts_in_image_frame[0, 0, j]
        v = pts_in_image_frame[1, 0, j]
        if not joints_valid[j]:
            continue  # invalid detections keep a zero residual
        du = joints_data[j][0] - u
        dv = joints_data[j][1] - v
        out[j] = joints_confidence[j] * 2 * math.sqrt(du ** 2 + dv ** 2)
    return out


def _select_links(args, mpi):
    """Return the skeleton link table that matches ``args['2d_detector']``.

    Each link is a dict with at least 'parent' and 'child' joint keys.
    """
    detector = args['2d_detector']
    if detector == "mediapipe":
        return links_mediapipe
    if detector == "openpose":
        # The MPI dataset uses its own OpenPose link table.
        return links_op_mpi if mpi else links_openpose
    # Any other detector value (e.g. 'groundtruth') uses the MPI-INF-3DHP table.
    return links_mpi_inf_3dhp


def _projection_residuals(cameras, poses3d, frame_list):
    """Build confidence-weighted reprojection residuals for every valid 2D joint.

    For each camera, the 3D joints of all frames in ``frame_list`` are stacked
    into one homogeneous 4xN matrix. That matrix is transformed by the camera
    extrinsics and projected with a single ``batch_projectToCamera`` call. The
    projection is then split back per frame and compared with the detected
    joints.

    Returns:
        Dict ``'projection_sensor_<cam>_frame_<frame>_joint_<joint>' -> residual``.
    """
    residuals = {}
    for camera_key, camera in cameras.items():
        frames = [camera['frames'][frame_key] for frame_key in frame_list]

        # Homogeneous world coordinates, one column per joint. The columns
        # follow the camera's own joint order within each frame.
        world_blocks = []
        for frame_key, frame in zip(frame_list, frames):
            pose = poses3d[frame_key]
            block = np.zeros((4, len(frame['joints'])), dtype=float)
            for idx, joint_key in enumerate(frame['joints']):
                block[0, idx] = pose[joint_key]['X']
                block[1, idx] = pose[joint_key]['Y']
                block[2, idx] = pose[joint_key]['Z']
            block[3, :] = 1
            world_blocks.append(block)

        pts_in_world = np.concatenate(world_blocks, axis=1)
        pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)
        # NOTE(review): the result is sliced as a 3-D array below
        # (presumably (2, 1, N)) -- confirm against batch_projectToCamera.
        pts_in_image = batch_projectToCamera(camera['intrinsics'], camera['distortion'], camera['width'],
                                             camera['height'], pts_in_sensor[0:3, :])

        # Walk the projected columns with a running offset, so frames with
        # different joint counts are still sliced correctly.
        offset = 0
        for frame_key, frame in zip(frame_list, frames):
            joints = frame['joints']
            n_joints = len(joints)
            pts_in_image_frame = pts_in_image[:, :, offset:offset + n_joints]
            offset += n_joints

            # The numba kernel needs plain arrays rather than dicts.
            joint_list = list(joints.values())
            joints_data = np.array([(joint['x'], joint['y']) for joint in joint_list])
            joints_valid = np.array([joint['valid'] for joint in joint_list])
            joints_confidence = np.array([joint['confidence'] for joint in joint_list])

            frame_residuals = update_joints_and_calculate_residuals(pts_in_image_frame, joints_data, joints_valid,
                                                                    joints_confidence)
            for idx, joint_key in enumerate(joints):
                if joints_valid[idx]:
                    residual_key = 'projection_sensor_' + camera_key + '_frame_' + frame_key + '_joint_' + joint_key
                    residuals[residual_key] = frame_residuals[idx]
    return residuals


def _frame_to_frame_residuals(cameras, poses3d, frame_list, max_displacement=50):
    """Build residuals penalising joint jumps between consecutive frames.

    Only pairs whose integer frame keys differ by exactly 1, and whose later
    frame is in ``frame_list``, are considered. A residual equals the 3D
    displacement in either of these cases:

    * the displacement exceeds ``max_displacement``;
    * the joint is invalid in either frame for any camera.

    Otherwise the residual is 0.
    """
    residuals = {}
    frames_to_use = set(frame_list)
    frame_keys = list(poses3d)
    for initial_frame_key, final_frame_key in zip(frame_keys, frame_keys[1:]):
        if int(final_frame_key) - int(initial_frame_key) != 1:
            continue  # frames are not consecutive
        if final_frame_key not in frames_to_use:
            continue

        initial_pose = poses3d[initial_frame_key]
        final_pose = poses3d[final_frame_key]
        for joint_key, initial_joint in initial_pose.items():
            final_joint = final_pose[joint_key]
            dist = pointsDistance(initial_joint['X'], initial_joint['Y'], initial_joint['Z'],
                                  final_joint['X'], final_joint['Y'], final_joint['Z'])

            # A joint counts as invalid if any camera marks it invalid in either frame.
            invalid_joint = any(
                not camera['frames'][initial_frame_key]['joints'][joint_key]['valid']
                or not camera['frames'][final_frame_key]['joints'][joint_key]['valid']
                for camera in cameras.values())

            residual_key = 'consecutive_frame_' + initial_frame_key + '_frame_' + final_frame_key + '_joint_' + joint_key
            residuals[residual_key] = dist if dist > max_displacement or invalid_joint else 0
    return residuals


def _link_length_residuals(poses3d, frame_list, links):
    """Build one residual per link: the mean absolute deviation of its length over ``frame_list``."""
    residuals = {}
    for link in links.values():
        lengths = []
        for frame_key in frame_list:
            pose = poses3d[frame_key]
            parent = pose[link['parent']]
            child = pose[link['child']]
            lengths.append(pointsDistance(parent['X'], parent['Y'], parent['Z'],
                                          child['X'], child['Y'], child['Z']))

        average = np.mean(lengths)
        # Sequential sum, so the float result matches a plain accumulation loop.
        deviation = sum(abs(length - average) for length in lengths) / len(lengths)
        residual_key = 'link_length_' + 'joint_' + link['parent'] + '_joint_' + link['child']
        residuals[residual_key] = deviation
    return residuals


def _report_debug_errors(poses3d, groundtruth, links, errors, iteration):
    """Store per-joint 3D errors against ground truth in ``errors[iteration]`` and print a table."""
    # Joints referenced by any link, de-duplicated in first-seen order so the
    # table columns are stable from run to run.
    joints_to_use = list(dict.fromkeys(
        joint for link in links.values() for joint in (link['parent'], link['child'])))

    table = PrettyTable(['Frame #'] + joints_to_use)
    errors[iteration] = {}
    for frame_key, pose in poses3d.items():
        frame_errors = errors[iteration][frame_key] = {}
        row = [frame_key]
        for joint_key in joints_to_use:
            joint = pose[joint_key]
            joint_gt = groundtruth[frame_key][joint_key]['pose']
            # Euclidean distance to ground truth; stored under the historical 'rmse' key.
            error = math.dist([joint['X'], joint['Y'], joint['Z']],
                              [joint_gt['x'], joint_gt['y'], joint_gt['z']])
            frame_errors[joint_key] = {'rmse': error}
            row.append('%.1f' % error)
        table.add_row(row)
    print(table)


def objectiveFunction(data):
    """Compute all residuals for the 3D pose optimisation.

    Args:
        data: dict with these keys:

            * 'args': options dict. Reads 'dataset_name', '2d_detector',
              'skip_frame_to_frame_residuals', 'skip_link_length_residuals'
              and 'debug'.
            * 'cameras', 'poses3d', 'groundtruth', 'frame_list', 'errors'.
            * 'status': read only in debug mode ('is_iteration',
              'num_iterations').

    Returns:
        Dict mapping residual names to values. Entries are inserted in this
        order: projection residuals, then frame-to-frame residuals (unless
        skipped), then link-length residuals (unless skipped).
    """
    args = data['args']
    cameras = data['cameras']
    poses3d = data['poses3d']
    groundtruth = data['groundtruth']
    frame_list = data['frame_list']
    errors = data['errors']

    mpi = args['dataset_name'] == "mpi"
    links = _select_links(args, mpi)

    residuals = _projection_residuals(cameras, poses3d, frame_list)

    if not args['skip_frame_to_frame_residuals']:
        residuals.update(_frame_to_frame_residuals(cameras, poses3d, frame_list))

    if not args['skip_link_length_residuals']:
        residuals.update(_link_length_residuals(poses3d, frame_list, links))

    if args['debug'] and data['status']['is_iteration']:
        _report_debug_errors(poses3d, groundtruth, links, errors, data['status']['num_iterations'])

    return residuals
