#!/usr/bin/env python3
"""
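Inspects an ATOM dataset file given as input. Prints the sensors, patterns and collections it contains, which collections are complete, which camera detections are partial, and statistics on the joint positions.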
"""


# Standard imports
import math
import signal
import sys
import argparse
from statistics import mean, stdev


# Atom imports
from colorama import Fore, Style
from atom_core.dataset_io import loadResultsJSON, filterSensorsFromDataset
from atom_core.utilities import addAveragesBottomRowToTable,createLambdaExpressionsForArgs

# Ros imports
from urdf_parser_py.urdf import URDF


# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------
def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl+C): print a notice and terminate with exit status 0.

    Both arguments are supplied by the ``signal`` module and are not used.
    """
    notice = 'Stopping optimization (Ctrl+C pressed)'
    print(notice)
    raise SystemExit(0)  # same effect as sys.exit(0)


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------
def main():
    """Inspect an ATOM dataset and print a human-readable report to stdout.

    The report contains, in order: the sensors, patterns and collections found
    in the dataset, the lists of complete and incomplete collections, the
    complete/partial detections of each rgb sensor per pattern, one table per
    calibration pattern with the detection status of every sensor in every
    collection and, when joints are configured, a table of joint statistics.

    Command line arguments are read from ``sys.argv``; returns ``None``.
    """
    signal.signal(signal.SIGINT, signal_handler)

    args = _parse_arguments()

    # ---------------------------------------
    # --- INITIALIZATION Read data from file
    # ---------------------------------------
    # Loads a json file containing the detections. The second returned value
    # (the resolved json path) is not needed here.
    dataset, _ = loadResultsJSON(args['json_file'])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    dataset = filterSensorsFromDataset(dataset, args)  # filter sensors

    sensor_keys = list(dataset['sensors'].keys())
    print('Loaded dataset containing ' + str(len(sensor_keys)) + ' sensors and ' + str(
        len(dataset['collections'].keys())) + ' collections.')

    print('Dataset contains ' + str(len(sensor_keys)) + ' sensors: ' + str(sensor_keys))

    pattern_keys = list(dataset['patterns'].keys())
    print('Dataset contains ' + str(len(pattern_keys)) + ' patterns: ' + str(pattern_keys))

    # Without collections there is nothing to analyse; stop with a clear
    # message instead of failing with an IndexError further down.
    if not dataset['collections']:
        print('Dataset contains no collections. Nothing to inspect.')
        return

    # ---------------------------------------
    # --- Define selected collection key.
    # ---------------------------------------
    # The first collection key in the dictionary is reported as the selected one.
    selected_collection_key = next(iter(dataset['collections']))
    print("Selected collection key is " + str(selected_collection_key))

    # ---------------------------------------
    # --- Count incomplete collections
    # ---------------------------------------
    complete_collections, incomplete_collections = _split_collections_by_completeness(dataset)
    print('Complete collections (' + str(len(complete_collections)) + '):' + str(complete_collections))
    print('Incomplete collections (' + str(len(incomplete_collections)) + '):' + str(incomplete_collections))

    # ---------------------------------------
    # --- Count partial detections
    # ---------------------------------------
    _report_partial_detections(dataset)

    # ---------------------------------------
    # --- Draw stylized table
    # ---------------------------------------
    _print_pattern_tables(dataset)

    # ---------------------------------------
    # --- Draw table with analysis of joints
    # ---------------------------------------
    _print_joints_table(dataset)


# -------------------------------------------------------------------------------
# --- HELPERS OF MAIN
# -------------------------------------------------------------------------------
def _parse_arguments():
    """Parse the command line and return the arguments as a dictionary.

    The returned dictionary has been passed through
    createLambdaExpressionsForArgs, so selection functions given as strings
    are expected to be lambdas afterwards.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-json", "--json_file", help="Json file containing input dataset.", type=str, required=True)
    ap.add_argument(
        "-ssf", "--sensor_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a sensor name as input and "
        "returns True or False to indicate if the sensor should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["left_laser", "frontal_camera"] , to load only '
        "sensors left_laser and frontal_camera")

    # Roslaunch adds two arguments (__name and __log) that break our parser. Let's remove those.
    arglist = [x for x in sys.argv[1:] if not x.startswith('__')]
    args = vars(ap.parse_args(args=arglist))
    return createLambdaExpressionsForArgs(args)  # selection functions are now lambdas


def _split_collections_by_completeness(dataset):
    """Return (complete, incomplete) lists of collection keys.

    A collection is complete when every pattern is marked as detected by
    every sensor in that collection.
    """
    complete_collections = []
    incomplete_collections = []
    for collection_key, collection in dataset['collections'].items():
        is_complete = all(collection['labels'][pattern_key][sensor_key]['detected']
                          for pattern_key in dataset['patterns']
                          for sensor_key in dataset['sensors'])
        if is_complete:
            complete_collections.append(collection_key)
        else:
            incomplete_collections.append(collection_key)

    return complete_collections, incomplete_collections


def _report_partial_detections(dataset):
    """Print, for each rgb sensor and pattern, the complete and partial detections.

    A detection is complete when the number of labelled corner indexes equals
    the number of corners of the pattern (dimension x times dimension y).
    Sensors of other modalities are only reported as not being cameras.
    """
    patterns = dataset['calibration_config']['calibration_patterns']

    for sensor_key, sensor in dataset['sensors'].items():
        if sensor['modality'] != 'rgb':
            print('Sensor ' + sensor_key + ' is not a camera. All detections are complete.')
            continue

        for pattern_key, pattern in patterns.items():
            number_of_corners = int(pattern['dimension']['x']) * int(pattern['dimension']['y'])

            # The classification is evaluated independently for each pattern,
            # so a partial detection of one pattern does not leak into the
            # result of the patterns that follow it.
            complete_detections = []
            partial_detections = []
            for collection_key, collection in dataset['collections'].items():
                number_of_idxs = len(collection['labels'][pattern_key][sensor_key]['idxs'])
                if number_of_idxs == number_of_corners:
                    complete_detections.append(collection_key)
                else:
                    partial_detections.append(collection_key)

            print('Sensor ' + sensor_key + ' has ' + str(len(complete_detections)) +
                  ' complete detections of pattern ' + pattern_key + ': ' + str(complete_detections))
            print('Sensor ' + sensor_key + ' has ' + str(len(partial_detections)) +
                  ' partial detections of pattern ' + pattern_key + ': ' + str(partial_detections))


def _detection_cell(labels, is_camera, number_of_corners):
    """Return (cell_text, detected) describing one sensor's labels of one pattern.

    Only cameras can produce a 'partial' cell, which is used when the pattern
    is detected but the number of labelled corners differs from
    number_of_corners.
    """
    if not labels['detected']:
        return Fore.RED + 'not detected' + Style.RESET_ALL, False

    if is_camera and len(labels['idxs']) != number_of_corners:
        return (Fore.BLUE + 'partial (' + str(len(labels['idxs'])) + '/' + str(number_of_corners) +
                ' corners detected)' + Style.RESET_ALL), True

    return Fore.GREEN + 'detected' + Style.RESET_ALL, True


def _print_pattern_tables(dataset):
    """Print one table per calibration pattern with the detections per collection and sensor."""
    from prettytable import PrettyTable

    sensor_keys = list(dataset['sensors'].keys())

    for pattern_key, pattern in dataset['calibration_config']['calibration_patterns'].items():

        # Number of corners needs to be computed for each pattern, otherwise it
        # would hold the value of the last one (issue #892)
        number_of_corners = int(pattern['dimension']['x']) * int(pattern['dimension']['y'])

        print('\nAnalysis for pattern ' + Style.BRIGHT + Fore.BLUE + pattern_key + Style.RESET_ALL)
        table = PrettyTable(['Collection', 'is complete'] + sensor_keys)

        for collection_key, collection in dataset['collections'].items():
            cells = []
            is_complete = True
            for sensor_key, sensor in dataset['sensors'].items():
                cell, detected = _detection_cell(collection['labels'][pattern_key][sensor_key],
                                                 sensor['modality'] == 'rgb', number_of_corners)
                cells.append(cell)
                is_complete = is_complete and detected

            if is_complete:
                status = Fore.GREEN + 'yes' + Style.RESET_ALL
            else:
                status = Fore.RED + 'no' + Style.RESET_ALL

            table.add_row([collection_key, status] + cells)

        table.align = 'c'

        print(Style.BRIGHT + '\nCollections' + Style.RESET_ALL)
        print(table)


def _print_joints_table(dataset):
    """Print statistics of the position of each configured joint over all collections.

    Does nothing when the calibration configuration has no joints (None).
    """
    from prettytable import PrettyTable

    joints = dataset['calibration_config']['joints']
    if joints is None:
        return

    header = ['Joint', 'Mean', 'Std', 'Min', 'Max', 'Range [deg]']
    table = PrettyTable(header)

    for joint_key in joints:
        # Get all positions in the dataset
        positions = [collection['joints'][joint_key]['position']
                     for collection in dataset['collections'].values()]

        lowest = min(positions)
        highest = max(positions)
        # stdev() needs at least two samples; a single collection has no spread.
        spread = stdev(positions) if len(positions) > 1 else 0.0

        # NOTE(review): the range is always converted as if positions were in
        # radians, which is not meaningful for prismatic joints - confirm how
        # those should be reported.
        values = [mean(positions), spread, lowest, highest, (highest - lowest) * 180.0 / math.pi]
        table.add_row([joint_key] + ['{:.5f}'.format(value) for value in values])

    # Add bottom row with averages
    table = addAveragesBottomRowToTable(table, header)

    print(Style.BRIGHT + '\nJoints' + Style.RESET_ALL)
    print(table)


# Run the inspection only when this file is executed as a script.
if __name__ == '__main__':
    main()
