#!/usr/bin/env python3

# System and standard imports
import argparse
import os
import os.path
import sys
from functools import partial

# ros imports
from atom_core.utilities import createLambdaExpressionsForArgs
import rospy
from atom_core.naming import generateLabeledTopic
from geometry_msgs.msg import PointStamped
from sensor_msgs.msg import PointCloud2

# 3rd-party imports
from colorama import Back, Fore, Style
from pynput import keyboard

# Atom imports
from atom_calibration.odometry_inspector.depth_manual_labeling import clickedPointsCallback, clickedPointsReset
from atom_calibration.odometry_inspector.lidar3d_manual_labeling import *
from atom_calibration.odometry_inspector.visualization import setupVisualization, visualizationFunction
from atom_core.dataset_io import (filterCollectionsFromDataset,
                                  filterSensorsFromDataset, loadResultsJSON, saveAtomDataset)

from atom_calibration.odometry_inspector.interactive_first_guess import InteractiveFirstGuess

# own packages


# global variables  ... are forbidden

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def printHelp():
    """Print the mouse and keyboard instructions of the inspector to stdout.

    The text is assembled from one entry per output line and joined with
    newlines, so that a forgotten line terminator cannot silently glue two
    instructions together (as happened with the "Alt" and "Shift" hints).
    """

    # NOTE(review): the topic names below are written as '/rviz_selected_*', while the subscribers
    # created in main() listen on '/rviz/selected_*'. Confirm which spelling is the correct one.
    lines = (
        'Drag with the left button to select objects in the 3D scene.',
        ' Hold the Alt key to change viewpoint as in the Move tool.',
        ' Holding the Shift key will allow adding to the current selection.',
        ' Holding the Ctrl key will allow subtraction from the current selection.',
        ' The following keys are also available:',
        ' p - Publish selected points to /rviz_selected_points',
        ' b - Publish selected points to /rviz_selected_border_points',
        ' r - Publish selected points to /rviz_selected_remove_points',
        ' c - Publish selected points to /rviz_selected_clear_all_points',
        ' d - Switch between detect (find pattern inside user defined polygon) and delete (delete boundary points inside polygon) modes for depth labeling',
        ' s - Save dataset',
        ' q - Save dataset and quit ',
        # Keys handled by keyPressedCallback which were missing from the help text.
        ' h - Print this help',
        ' Right arrow - Save dataset and move to the next collection',
        ' Left arrow - Save dataset and move to the previous collection',
    )

    # The trailing newline (plus the one added by print) keeps a blank line after the help.
    print('\n'.join(lines) + '\n')


def _saveAndReport(output_file, dataset):
    """Save the dataset to ``output_file`` and tell the user where it was written."""
    saveAtomDataset(output_file, dataset, freeze_dataset=True)
    print('A new dataset was saved in ' + output_file)


def _changeCollection(step, selection, dataset, depth_mode, output_file):
    """Move the selection ``step`` collections forward (+1) or backward (-1) and save the dataset.

    Nothing is changed (only a warning is printed) when the move would go past the first or the
    last collection. Raises ValueError if the currently selected key is not in the dataset.
    """
    # Navigation is done by position in the key list, so gaps in the key names (e.g. collections
    # '0' and '2' only) are handled naturally.
    collection_keys = list(dataset['collections'].keys())
    idx_target = collection_keys.index(selection['collection_key']) + step

    if idx_target < 0:
        print(Fore.RED + 'This is the first collection!!' + Fore.RESET)
        return
    if idx_target >= len(collection_keys):
        print(Fore.RED + 'This is the last collection!!' + Fore.RESET)
        return

    selection['previous_collection_key'] = selection['collection_key']
    selection['collection_key'] = collection_keys[idx_target]
    print('Changed selected_collection_key to ' + str(selection['collection_key']))
    saveAtomDataset(output_file, dataset, freeze_dataset=True)

    # More intuitive if detect mode is True at the beginning of a new collection
    depth_mode['mode'] = 'detect'
    print('Changed depth mode to ' + depth_mode['mode'])


def keyPressedCallback(key, selection, dataset, args, depth_mode, output_file):
    """Keyboard callback: save / quit / help / depth-mode toggle / collection navigation.

    Args:
        key: Key object delivered by the keyboard listener. Keys that expose a ``char`` attribute
            are treated as character keys, all others are compared against the arrow keys.
            ``None`` is ignored.
        selection: Navigation state dict; 'collection_key', 'previous_collection_key' and 'exit'
            are read and/or written here.
        dataset: Dataset dict; its 'collections' keys define the navigation order.
        args: Program arguments (currently unused, kept for interface compatibility).
        depth_mode: Dict whose 'mode' entry is switched between 'detect' and 'delete'.
        output_file: Path given to saveAtomDataset whenever the dataset is saved.
    """

    if key is None:
        return
    # TODO #402 reported that navigation did not work if some of the collections in the middle,
    # e.g. '1', are removed. An easy test is to run a -csf "lambda x: x in ['0', '2']"

    # Decide the kind of key up front instead of wrapping the handlers in a try/except
    # AttributeError: that pattern also swallowed AttributeErrors raised *inside* the handlers
    # (e.g. while saving) and misrouted them to the arrow key branch.
    if hasattr(key, 'char'):
        key_char = key.char
        if key_char == 's':  # Saves dataset.
            _saveAndReport(output_file, dataset)
        elif key_char == 'q':  # Saves dataset and quits.
            _saveAndReport(output_file, dataset)
            print('Exiting ...')
            selection['exit'] = True
        elif key_char == 'h':  # Prints help.
            printHelp()
        elif key_char == 'd':  # Toggles the depth labeling mode.
            depth_mode['mode'] = 'delete' if depth_mode['mode'] == 'detect' else 'detect'
            print('Changed depth mode to ' + depth_mode['mode'])
        return

    # Special (non character) keys. The collection index is only looked up here, so that the
    # character keys above keep working whatever the state of the selection is.
    if key == keyboard.Key.right:  # Save and move to collection + 1.
        _changeCollection(+1, selection, dataset, depth_mode, output_file)
    elif key == keyboard.Key.left:  # Save and move to collection - 1.
        _changeCollection(-1, selection, dataset, depth_mode, output_file)


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


def main():
    """Entry point of the interactive odometry inspector.

    Parses the command line, loads the dataset json, sets up the visualization, subscribes to the
    point selection (lidar3d) and mouse click (depth) topics, starts the keyboard listener and
    the interactive first guess markers, and then refreshes the visualization at 10 Hz until ROS
    shuts down or the user asks to quit (``selection['exit']``).

    The dataset is written by the keyboard callback to ``dataset_corrected.json``, placed next to
    the loaded json file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-json", "--json_file",
                    help="Json file containing input dataset.", type=str, required=True)
    ap.add_argument(
        "-ow", "--overwrite",
        help="Overwrites the dataset_corrected.json without asking for permission.",
        action='store_true')
    ap.add_argument(
        "-rnb", "--remove_nan_border",
        help="Option for the labeling of depth images. It Will run a detection of nan values in the image, searching for the actual area of the image which is used. Then, border detection will use this estimated area...",
        action='store_true')
    ap.add_argument(
        "-rc", "--reference_collection",
        help="", required=True)
    ap.add_argument(
        "-rs", "--reference_sensor",
        help="", required=True)
    # ap.add_argument(
    # "-2d", "--2d_mode", required=True,
    # help="Interactive marker only moves in 2D. Useful for odometry transformations",)
    ap.add_argument("-op", "--odometry_parent", help="", required=True)
    ap.add_argument("-oc", "--odometry_child", help="", required=True)
    ap.add_argument("-s", "--marker_scale", type=float, default=0.9,
                    help='Scale of the interactive markers.')
    ap.add_argument(
        "-csf", "--collection_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a collection name as input and "
        "returns True or False to indicate if the collection should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")

    # Roslaunch adds two arguments (__name and __log) that break our parser. Lets remove those.
    arglist = [x for x in sys.argv[1:] if not x.startswith('__')]
    # these args have the selection functions as strings
    args_original = vars(ap.parse_args(args=arglist))
    args = createLambdaExpressionsForArgs(args_original)  # selection functions are now lambdas

    # Arguments which are not options for odometry_inspector
    args['ros_visualization'] = True
    args['show_images'] = True
    args['use_incomplete_collections'] = True
    args['remove_partial_detections'] = False
    # 'collection_selection_function' is deliberately NOT reset here: it is a declared command
    # line option (-csf, None by default) and resetting it made the option silently ignored.

    # ---------------------------------------
    # --- INITIALIZATION Read data from file
    # ---------------------------------------
    # Loads a json file containing the detections. Returned json_file has path resolved by urireader.
    dataset, json_file = loadResultsJSON(args['json_file'], args['collection_selection_function'])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    # dataset = filterCollectionsFromDataset(dataset, args)  # During dataset review there is no need to filter out collections

    # Use the resolved path: with the raw command line string the overwrite guard below would
    # test a path which may not be a real filesystem path, and never ask for confirmation.
    output_file = os.path.join(os.path.dirname(json_file), 'dataset_corrected.json')
    if os.path.exists(output_file) and json_file != output_file and not args['overwrite']:
        ans = input('The file dataset_corrected.json already exists.'
                    ' Do you want to overwrite? (Y/n)')
        # Anything other than an explicit "n"/"no" counts as yes (the default of the prompt).
        if ans.strip().lower() in ('n', 'no'):
            sys.exit(0)

    dataset = filterSensorsFromDataset(dataset, args)  # filter sensors

    print('Loaded dataset containing ' + str(len(dataset['sensors'].keys())) + ' sensors and ' +
          str(len(dataset['collections'].keys())) + ' collections.')

    collection_keys = list(dataset['collections'].keys())
    if not collection_keys:  # nothing to navigate through, fail with a clear message
        print(Fore.RED + 'The dataset has no collections, nothing to inspect.' + Fore.RESET)
        sys.exit(1)

    depth_mode = {'mode': 'detect'}

    # ---------------------------------------
    # --- Define selection
    # ---------------------------------------
    # Lets start with the first key on the collections dictionary.
    # Data structure used to save the state of navigation throughout the collections in the dataset.
    selection = {'collection_key': collection_keys[0],
                 'previous_collection_key': None, 'exit': False}

    print("Configuring visualization ... ")
    graphics = setupVisualization(dataset, args, selection['collection_key'])

    # ---------------------------------------
    # --- lidar3d modality subscribers
    # ---------------------------------------

    # Define subscribers to receive the selected points, one callback per selection topic
    lidar3d_callbacks = (
        ("/rviz/selected_points", selectedPointsCallback),
        ("/rviz/selected_border_points", selectedPointsBorderCallback),
        ("/rviz/selected_remove_points", selectedPointsRemoveCallback),
        ("/rviz/selected_clear_all_points", selectedPointsClearAllCallback),
    )
    for topic, callback in lidar3d_callbacks:
        rospy.Subscriber(topic, PointCloud2,
                         partial(callback, selection=selection, dataset=dataset))

    # ---------------------------------------
    # --- Depth modality subscribers
    # ---------------------------------------
    depth_sensor_keys = [sensor_key for sensor_key, sensor in dataset['sensors'].items()
                         if sensor['modality'] == 'depth']

    # Subscriber for the image_click plugin
    # this stores the pixel coordinates of clicked points per collection, and per sensor key
    clicked_points = {collection_key: {} for collection_key in dataset['collections']}

    # Initialize clicked points for all collections and depth sensors
    for collection_key in dataset['collections']:
        for sensor_key in depth_sensor_keys:
            clickedPointsReset(clicked_points, collection_key, sensor_key)

    # Create a subscriber for each depth sensor
    for sensor_key in depth_sensor_keys:
        points_topic = generateLabeledTopic(
            dataset['sensors'][sensor_key]['topic'],
            type='2d') + '/mouse_click'
        rospy.Subscriber(points_topic, PointStamped,
                         partial(
                             clickedPointsCallback, clicked_points=clicked_points,
                             dataset=dataset, sensor_key=sensor_key, selection=selection,
                             depth_mode=depth_mode, args=args))

    # ---------------------------------------
    # --- RGB modality subscribers
    # ---------------------------------------
    # TODO #394 To implement in the future.

    # ---------------------------------------
    # --- Define callback to navigate through collections
    # ---------------------------------------
    listener = keyboard.Listener(
        on_press=partial(
            keyPressedCallback, selection=selection, dataset=dataset, args=args,
            depth_mode=depth_mode, output_file=output_file))
    listener.start()

    try:
        # ---------------------------------------
        # Setting up interactive marker
        # ---------------------------------------
        first_guess = InteractiveFirstGuess(args, dataset, selection)
        first_guess.init()

        # ---------------------------------------
        # --- Loop while displaying selected collection
        # ---------------------------------------
        print('\nPress ' + Fore.BLUE + '"h"' + Style.RESET_ALL + ' for instructions.\n')

        rate = rospy.Rate(10)  # in hertz.
        current_collection = None  # no collection displayed yet, forces an update on the first pass
        while not rospy.is_shutdown() and not selection['exit']:
            models = {'dataset': dataset, 'args': args, 'graphics': graphics}

            # Read the selected key only once per iteration: the keyboard listener runs in its
            # own thread and may change it between a comparison and a later re-read.
            selected_collection = selection['collection_key']
            if selected_collection != current_collection:
                first_guess.additional_tf.updateAll()
                first_guess.additional_tf.updateMarkers()
                current_collection = selected_collection

            visualizationFunction(models=models, selection=selection,
                                  clicked_points=clicked_points)
            rate.sleep()
    except rospy.ROSInterruptException:
        # Raised by rate.sleep() when ROS shuts down during the sleep: this is a normal way of
        # ending the program, not an error.
        pass
    finally:
        listener.stop()  # do not leave the keyboard hook running after the loop ends


# Run the inspector only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
