import math

import numpy as np
from utils.projection import projectToCamera, projectToCamera_faster
from utils.links import links_ground_truth, links_openpose, joint_idxs, joints_mediapipe, links_mediapipe, \
    joint_idxs_op_mpi, links_op_mpi, links_mpi_inf_3dhp, joints_mpi_inf_3dhp
from pytictoc import TicToc

from prettytable import PrettyTable
from colorama import Style, Fore


def pointsDistance(xi, yi, zi, xf, yf, zf):
    """Return the Euclidean distance between (xi, yi, zi) and (xf, yf, zf)."""
    deltas = (xf - xi, yf - yi, zf - zi)
    return math.sqrt(sum(delta ** 2 for delta in deltas))


def average(lst):
    """Return the arithmetic mean of ``lst`` as a float.

    An empty ``lst`` raises ZeroDivisionError (float division by zero).
    """
    total = float(sum(lst))
    count = float(len(lst))
    return total / count


def _select_links(detector_name, mpi):
    """Return the skeleton link definitions for the configured 2D detector.

    ``'mediapipe'`` and ``'openpose'`` select their own link sets (OpenPose uses
    its MPI variant when ``mpi`` is True); any other detector name falls back to
    the MPI-INF-3DHP links.
    """
    if detector_name == "mediapipe":
        return links_mediapipe
    if detector_name == "openpose":
        return links_op_mpi if mpi else links_openpose
    return links_mpi_inf_3dhp


def _add_projection_residuals(residuals, cameras, poses3d, frame_list):
    """Add one reprojection residual per valid 2D joint detection.

    Every joint of every frame in ``frame_list``, in every camera, is projected
    into that camera's image. The projected pixel is stored back on the joint as
    ``x_proj`` / ``y_proj``. For joints flagged ``valid``, the residual is
    ``confidence * 2 * ||detected - projected||``.
    """
    for camera_key, camera in cameras.items():
        for frame_key in map(str, frame_list):
            for joint_key, joint in camera['frames'][frame_key]['joints'].items():
                point = poses3d[frame_key][joint_key]

                # Homogeneous world coordinates as a (4, 1) column vector.
                pts_in_world = np.empty((4, 1), dtype=float)
                pts_in_world[0][0] = point['X']
                pts_in_world[1][0] = point['Y']
                pts_in_world[2][0] = point['Z']
                pts_in_world[3][0] = 1

                # Transform to the camera's coordinate frame, then project to pixels.
                pts_in_sensor = np.dot(camera['extrinsics'], pts_in_world)
                pts_in_image, _ = projectToCamera_faster(camera['intrinsics'], camera['distortion'],
                                                         camera['width'], camera['height'],
                                                         pts_in_sensor[0:3, :])

                x_proj = pts_in_image[0][0]
                y_proj = pts_in_image[1][0]
                joint['x_proj'] = x_proj
                joint['y_proj'] = y_proj

                if joint['valid']:
                    residual_key = f'projection_sensor_{camera_key}_frame_{frame_key}_joint_{joint_key}'
                    residuals[residual_key] = joint['confidence'] * 2 * math.sqrt(
                        (joint['x'] - x_proj) ** 2 + (joint['y'] - y_proj) ** 2)


def _add_consecutive_frame_residuals(residuals, cameras, poses3d, frame_list, max_jump=150):
    """Add one smoothness residual per joint for each pair of consecutive frames.

    Neighbouring keys of ``poses3d`` (in its iteration order) are paired. A pair
    is used only when the integer keys differ by exactly one and the later frame
    is in ``frame_list``. The residual is the 3D displacement of the joint
    between the two frames when that displacement exceeds ``max_jump`` (same
    units as the pose coordinates), or when the joint is not ``valid`` in both
    frames for every camera. Otherwise the residual is 0.
    """
    selected_frames = set(frame_list)  # built once: O(1) membership per frame pair
    frame_keys = list(poses3d)
    for initial_frame_key, final_frame_key in zip(frame_keys, frame_keys[1:]):
        if int(final_frame_key) - int(initial_frame_key) != 1:  # frames are not consecutive
            continue
        if int(final_frame_key) not in selected_frames:
            continue

        initial_pose = poses3d[initial_frame_key]
        final_pose = poses3d[final_frame_key]
        for joint_key, initial_point in initial_pose.items():
            final_point = final_pose[joint_key]
            distance = pointsDistance(initial_point['X'], initial_point['Y'], initial_point['Z'],
                                      final_point['X'], final_point['Y'], final_point['Z'])

            # The joint counts as reliable only if every camera detected it in both frames.
            detected_everywhere = all(
                camera['frames'][initial_frame_key]['joints'][joint_key]['valid']
                and camera['frames'][final_frame_key]['joints'][joint_key]['valid']
                for camera in cameras.values())

            residual_key = f'consecutive_frame_{initial_frame_key}_frame_{final_frame_key}_joint_{joint_key}'
            if distance > max_jump or not detected_everywhere:
                residuals[residual_key] = distance
            else:
                residuals[residual_key] = 0


def _estimated_link_length(pose, link):
    """Length of ``link`` in an estimated pose whose joints hold 'X', 'Y', 'Z'."""
    parent, child = pose[link['parent']], pose[link['child']]
    return pointsDistance(parent['X'], parent['Y'], parent['Z'], child['X'], child['Y'], child['Z'])


def _groundtruth_link_length(gt_frame, link):
    """Length of ``link`` in a ground-truth frame whose joints hold a 'pose' with 'x', 'y', 'z'."""
    parent = gt_frame[link['parent']]['pose']
    child = gt_frame[link['child']]['pose']
    return pointsDistance(parent['x'], parent['y'], parent['z'], child['x'], child['y'], child['z'])


def _add_link_length_residuals(residuals, links, poses3d, groundtruth, frame_list, iteration):
    """Add one link-length residual per skeleton link and build the debug tables.

    For each link, the residual is the mean absolute deviation of its estimated
    length over ``frame_list`` from that link's average estimated length. It is
    zero when the link keeps a constant length across frames.

    Returns:
        ``(table, table_gt)``. ``table`` holds per-link averages for this
        iteration. ``table_gt`` holds per-frame ground-truth link lengths, with
        rows only when ``iteration == 0``.
    """
    frame_keys = [str(frame_key) for frame_key in frame_list]

    table = PrettyTable(['Iteration #', 'Frame', 'Parent', 'Child', 'Groundtruth', 'Measured', 'Error',
                         'Std Deviation'])
    table_gt = PrettyTable(['Frame #'] + list(links.keys()))

    if iteration == 0:
        for frame_key in frame_keys:
            row = [frame_key]
            row.extend('%.1f' % _groundtruth_link_length(groundtruth[frame_key], link)
                       for link in links.values())
            table_gt.add_row(row)

    for link in links.values():
        groundtruth_length = average([_groundtruth_link_length(groundtruth[k], link) for k in frame_keys])
        measured_lengths = [_estimated_link_length(poses3d[k], link) for k in frame_keys]
        measured_length = average(measured_lengths)

        # Mean absolute deviation (the table column is labelled 'Std Deviation').
        deviation = sum(abs(length - measured_length) for length in measured_lengths) / len(measured_lengths)

        table.add_row([
            str(iteration),
            Fore.BLUE + Style.BRIGHT + 'AVERAGE ' + Style.RESET_ALL,
            link['parent'],
            link['child'],
            '%.1f' % groundtruth_length,
            '%.1f' % measured_length,
            '%.1f' % abs(groundtruth_length - measured_length),
            '%.1f' % deviation,
        ])

        residuals['link_length_joint_' + link['parent'] + '_joint_' + link['child']] = deviation

    return table, table_gt


def _record_joint_errors(errors, iteration, poses3d, groundtruth, links):
    """Store and print the per-joint 3D error against ground truth for every frame.

    The joints are those referenced by ``links``, kept in first-appearance order.
    Each error is stored as ``errors[iteration][frame][joint]['rmse']``; the key
    name is kept for compatibility, although the value is a plain Euclidean distance.
    """
    joints_to_use = list(dict.fromkeys(
        joint for link in links.values() for joint in (link['parent'], link['child'])))

    table = PrettyTable(['Frame #'] + joints_to_use)
    errors[iteration] = {}
    for frame_key, pose in poses3d.items():
        frame_errors = errors[iteration][frame_key] = {}
        row = [frame_key]
        for joint_key in joints_to_use:
            estimate = pose[joint_key]
            truth = groundtruth[frame_key][joint_key]['pose']
            error = math.dist([estimate['X'], estimate['Y'], estimate['Z']],
                              [truth['x'], truth['y'], truth['z']])
            frame_errors[joint_key] = {'rmse': error}
            row.append('%.1f' % error)
        table.add_row(row)
    print(table)


def _print_timings(proj_time, ll_time, ff_time):
    """Print how long each residual family took and its share of the total."""
    total_time = proj_time + ll_time + ff_time
    for label, elapsed in (("Projection residuals", proj_time),
                           ("Link length residuals", ll_time),
                           ("Frame to frame residuals", ff_time)):
        # Guard a zero total instead of raising ZeroDivisionError.
        share = round(elapsed / total_time * 100, 2) if total_time else 0.0
        print(f"{label} took {elapsed}s and {share}%")


def objectiveFunction(data):
    """Compute the residual dictionary for one evaluation of the 3D pose optimisation.

    Residual families, inserted in this order:

    * ``projection_sensor_<camera>_frame_<frame>_joint_<joint>``: reprojection
      error of each valid 2D detection.
    * ``consecutive_frame_<f0>_frame_<f1>_joint_<joint>``: frame-to-frame joint
      displacement. Skipped when ``args['skip_frame_to_frame_residuals']`` is set.
    * ``link_length_joint_<parent>_joint_<child>``: variability of each link's
      length across ``frame_list``. Skipped when
      ``args['skip_link_length_residuals']`` is set.

    Args:
        data: dict with keys ``'args'``, ``'cameras'``, ``'poses3d'``,
            ``'groundtruth'``, ``'frame_list'`` and ``'status'``
            (``'num_iterations'``, ``'is_iteration'``). ``'errors'`` is also
            needed when ``args['debug']`` is set during an iteration.

    Returns:
        dict mapping residual names to scalar residual values.

    Side effects:
        Writes ``x_proj`` / ``y_proj`` onto every camera joint of the frames in
        ``frame_list``. With ``args['debug']``, prints link-length and joint-error
        tables and fills ``data['errors'][iteration]``. With ``args['print_time']``
        during an iteration, prints the time spent on each residual family.
    """
    args = data['args']
    cameras = data['cameras']
    poses3d = data['poses3d']
    groundtruth = data['groundtruth']
    frame_list = data['frame_list']
    status = data['status']
    iteration = status['num_iterations']

    # Only the MPI dataset changes the skeleton (OpenPose MPI variant).
    mpi = args['dataset_name'] == "mpi"
    links = _select_links(args['2d_detector'], mpi)

    residuals = {}
    proj_time = ll_time = ff_time = 0
    report_time = args['print_time'] and status['is_iteration']
    timer = TicToc()

    # Projection residuals ------------------------------------------------
    timer.tic()
    _add_projection_residuals(residuals, cameras, poses3d, frame_list)
    if report_time:
        proj_time = timer.tocvalue()

    # Consecutive frame distance ------------------------------------------
    timer.tic()
    if not args['skip_frame_to_frame_residuals']:
        _add_consecutive_frame_residuals(residuals, cameras, poses3d, frame_list)
    if report_time:
        ff_time = timer.tocvalue()

    # Link length residuals -----------------------------------------------
    timer.tic()
    table = table_gt = None
    if not args['skip_link_length_residuals']:
        table, table_gt = _add_link_length_residuals(residuals, links, poses3d, groundtruth,
                                                     frame_list, iteration)
    if report_time:
        ll_time = timer.tocvalue()
        _print_timings(proj_time, ll_time, ff_time)

    if table is not None and args['debug']:
        if status['is_iteration']:
            print(table)
        if iteration == 0:
            print(table_gt)

    if args['debug'] and status['is_iteration']:
        _record_joint_errors(data['errors'], iteration, poses3d, groundtruth, links)

    return residuals
