#!/usr/bin/env python3

# System and standard imports
import argparse
import os
import os.path
import sys
from functools import partial

# ros imports
import rospy
from atom_core.naming import generateLabeledTopic
from geometry_msgs.msg import PointStamped
from sensor_msgs.msg import PointCloud2

# 3rd-party imports
from colorama import Back, Fore, Style
from pynput import keyboard

# Atom imports
from atom_calibration.dataset_playback.depth_manual_labeling import clickedPointsCallback, clickedPointsReset
from atom_calibration.dataset_playback.lidar3d_manual_labeling import *
from atom_calibration.dataset_playback.visualization import setupVisualization, visualizationFunction
from atom_core.dataset_io import (filterCollectionsFromDataset, filterSensorsFromDataset, loadResultsJSON,
                                  saveAtomDataset)

# own packages


# global variables  ... are forbidden

# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------


def printHelp():
    """Print the mouse and keyboard instructions of the dataset playback tool.

    The topic names listed for the p/b/r/c keys match the topics subscribed to
    in main(), and the remaining keys match the ones handled by
    keyPressedCallback(). Takes no arguments and returns nothing.
    """

    # One instruction per entry so that a forgotten line break cannot merge two
    # instructions on the same output line.
    lines = [
        'Drag with the left button to select objects in the 3D scene.',
        ' Hold the Alt key to change viewpoint as in the Move tool.',
        ' Holding the Shift key will allow adding to the current selection.',
        ' Holding the Ctrl key will allow subtraction from the current selection.',
        ' The following keys are also available:',
        ' p - Publish selected points to /rviz/selected_points',
        ' b - Publish selected points to /rviz/selected_border_points',
        ' r - Publish selected points to /rviz/selected_remove_points',
        ' c - Publish selected points to /rviz/selected_clear_all_points',
        ' d - Switch between detect (find pattern inside user defined polygon) and delete (delete boundary points inside polygon) modes for depth labeling',
        ' right arrow - Save dataset and move to the next collection',
        ' left arrow - Save dataset and move to the previous collection',
        ' h - Print this help',
        ' s - Save dataset',
        ' q - Save dataset and quit',
    ]

    # Trailing newline keeps an empty line after the help text, as before.
    print('\n'.join(lines) + '\n')


def _changeCollection(step, selection, dataset, depth_mode, output_file):
    """Move the selected collection by `step` positions, then save the dataset.

    :param step: +1 to move to the next collection, -1 to move to the previous one.
    :param selection: navigation state dict; 'collection_key' and 'previous_collection_key' are updated.
    :param dataset: dataset dict; the order of dataset['collections'] defines the navigation order.
    :param depth_mode: dict whose 'mode' entry is reset to 'detect' after a successful move.
    :param output_file: path given to saveAtomDataset.
    """
    collection_keys = list(dataset['collections'].keys())

    try:
        idx_collection = collection_keys.index(selection['collection_key'])
    except ValueError:
        # A stale selection must not raise inside the keyboard listener callback.
        print(Fore.RED + 'Selected collection ' + str(selection['collection_key']) +
              ' does not exist in the dataset!!' + Fore.RESET)
        return

    idx_target = idx_collection + step
    if idx_target < 0:
        print(Fore.RED + 'This is the first collection!!' + Fore.RESET)
        return
    if idx_target >= len(collection_keys):
        print(Fore.RED + 'This is the last collection!!' + Fore.RESET)
        return

    selection['previous_collection_key'] = selection['collection_key']
    selection['collection_key'] = collection_keys[idx_target]
    print('Changed selected_collection_key to ' + str(selection['collection_key']))
    saveAtomDataset(output_file, dataset, freeze_dataset=True)

    # More intuitive if detect mode is True at the beginning of a new collection
    depth_mode['mode'] = 'detect'
    print('Changed depth mode to ' + depth_mode['mode'])


def keyPressedCallback(key, selection, dataset, args, depth_mode, output_file):
    """Handle a key press reported by the keyboard listener.

    Character keys: 's' saves the dataset, 'q' saves and sets selection['exit'],
    'h' prints the help and 'd' toggles depth_mode['mode'] between 'detect' and
    'delete'. The right / left arrow keys save the dataset and move the
    selection to the next / previous collection. Any other key is ignored.

    :param key: key object given by the listener; character keys expose a `char` attribute,
                special keys (e.g. keyboard.Key.right) do not.
    :param selection: navigation state dict ('collection_key', 'previous_collection_key', 'exit').
    :param dataset: dataset dict being reviewed.
    :param args: command line arguments dict (not used here, kept for the callback signature).
    :param depth_mode: dict with the current depth labeling mode under 'mode'.
    :param output_file: path where the dataset is saved.
    """

    if key is None:
        return
    # TODO #402 does not work if some of the collections in the middle, e.g. '1', are removed. An easy test is to run a -csf "lambda x: x in ['0', '2']"

    # Special keys have no `char` attribute. Reading it with a default, instead of wrapping the
    # whole handler in try/except AttributeError, ensures that an AttributeError raised while
    # saving or printing is not mistaken for "this is a special key" and silently dropped.
    key_char = getattr(key, 'char', None)

    if key_char is not None:
        if key_char == 's':  # Saves dataset.
            saveAtomDataset(output_file, dataset, freeze_dataset=True)
            print('A new dataset was saved in ' + output_file)
        elif key_char == 'q':  # Saves dataset and quits.
            saveAtomDataset(output_file, dataset, freeze_dataset=True)
            print('A new dataset was saved in ' + output_file)
            print('Exiting ...')
            selection['exit'] = True
        elif key_char == 'h':  # Prints help.
            printHelp()
        elif key_char == 'd':  # Toggles the depth labeling mode.
            depth_mode['mode'] = 'delete' if depth_mode['mode'] == 'detect' else 'detect'
            print('Changed depth mode to ' + depth_mode['mode'])
        return

    # The collection index is only looked up when navigating, so the keys above keep working
    # even if the current selection does not match the dataset.
    if key == keyboard.Key.right:  # Save and move to collection + 1.
        _changeCollection(+1, selection, dataset, depth_mode, output_file)
    elif key == keyboard.Key.left:  # Save and move to collection - 1.
        _changeCollection(-1, selection, dataset, depth_mode, output_file)


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------


def main():
    """Run the interactive dataset playback and manual labeling tool.

    Parses the command line, loads the dataset, asks before overwriting an
    existing dataset_corrected.json (unless --overwrite is given), sets up the
    visualization, subscribes to the lidar3d selection topics and to the depth
    mouse click topics, starts the keyboard listener and then refreshes the
    visualization at 10 Hz until ROS shuts down or the user quits.

    NOTE(review): no rospy.init_node call is visible in this function; presumably
    setupVisualization initializes the node - confirm.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-json", "--json_file",
                    help="Json file containing input dataset.", type=str, required=True)
    ap.add_argument("-ow", "--overwrite",
                    help="Overwrites the dataset_corrected.json without asking for permission.", action='store_true')
    ap.add_argument("-rnb", "--remove_nan_border",
                    help="Option for the labeling of depth images. It will run a detection of nan values in the image, searching for the actual area of the image which is used. Then, border detection will use this estimated area...", action='store_true')

    # Roslaunch adds two arguments (__name and __log) that break our parser. Lets remove those.
    arglist = [x for x in sys.argv[1:] if not x.startswith('__')]
    args = vars(ap.parse_args(args=arglist))

    # Arguments which are not options for dataset_playback
    args['ros_visualization'] = True
    args['show_images'] = True
    args['use_incomplete_collections'] = True
    args['remove_partial_detections'] = False
    args['collection_selection_function'] = None

    # ---------------------------------------
    # --- INITIALIZATION Read data from file
    # ---------------------------------------
    # Loads a json file containing the detections. Returned json_file has path resolved by urireader.
    dataset, json_file = loadResultsJSON(args['json_file'], args['collection_selection_function'])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    # During dataset review there is no need to filter out collections, only sensors are filtered (below).

    # The corrected dataset is written next to the input dataset. The resolved path is used so
    # that the output location is a real filesystem path even if the argument was not one.
    output_file = os.path.join(os.path.dirname(json_file), 'dataset_corrected.json')
    if os.path.exists(output_file) and json_file != output_file and not args['overwrite']:
        ans = input('The file dataset_corrected.json already exists.'
                    ' Do you want to overwrite? (Y/n)')
        # Anything other than an explicit "n" / "no" is taken as the default answer (yes).
        if ans.strip().lower() in ('n', 'no'):
            sys.exit(0)

    dataset = filterSensorsFromDataset(dataset, args)  # filter sensors

    print('Loaded dataset containing ' + str(len(dataset['sensors'].keys())) + ' sensors and ' + str(
        len(dataset['collections'].keys())) + ' collections.')

    # Without collections there is nothing to select; fail with a clear message instead of an
    # IndexError when picking the first collection below.
    if not dataset['collections']:
        print(Fore.RED + 'The dataset has no collections, nothing to play back.' + Fore.RESET)
        sys.exit(1)

    depth_mode = {'mode': 'detect'}

    # ---------------------------------------
    # --- Define selection
    # ---------------------------------------
    # Lets start with the first key on the collections dictionary.
    # Data structure used to save the state of navigation throughout the collections in the dataset.
    selection = {'collection_key': next(iter(dataset['collections'])), 'previous_collection_key': None,
                 'exit': False}

    print("Configuring visualization ... ")
    graphics = setupVisualization(dataset, args, selection['collection_key'])

    # ---------------------------------------
    # --- lidar3d modality subscribers
    # ---------------------------------------

    # One subscriber per selection topic, each bound to the shared selection and dataset.
    lidar3d_callbacks = {"/rviz/selected_points": selectedPointsCallback,
                         "/rviz/selected_border_points": selectedPointsBorderCallback,
                         "/rviz/selected_remove_points": selectedPointsRemoveCallback,
                         "/rviz/selected_clear_all_points": selectedPointsClearAllCallback}
    for topic, callback in lidar3d_callbacks.items():
        rospy.Subscriber(topic, PointCloud2, partial(callback, selection=selection, dataset=dataset))

    # ---------------------------------------
    # --- Depth modality subscribers
    # ---------------------------------------

    depth_sensor_keys = [sensor_key for sensor_key, sensor in dataset['sensors'].items()
                         if sensor['modality'] == 'depth']

    # Subscriber for the image_click plugin
    # this stores the pixel coordinates of clicked points per collection, and per sensor key
    clicked_points = {collection_key: {} for collection_key in dataset['collections']}

    # Initialize clicked points for all collections and depth sensors
    for collection_key in dataset['collections']:
        for sensor_key in depth_sensor_keys:
            clickedPointsReset(clicked_points, collection_key, sensor_key)

    # Create a subscriber for each depth sensor
    for sensor_key in depth_sensor_keys:
        points_topic = generateLabeledTopic(dataset['sensors'][sensor_key]['topic'], type='2d') + '/mouse_click'
        rospy.Subscriber(points_topic, PointStamped,
                         partial(clickedPointsCallback, clicked_points=clicked_points,
                                 dataset=dataset, sensor_key=sensor_key, selection=selection,
                                 depth_mode=depth_mode, args=args))

    # ---------------------------------------
    # --- RGB modality subscribers
    # ---------------------------------------
    # TODO #394 To implement in the future.

    # ---------------------------------------
    # --- Define callback to navigate through collections
    # ---------------------------------------
    listener = keyboard.Listener(on_press=partial(keyPressedCallback, selection=selection, dataset=dataset, args=args,
                                                  depth_mode=depth_mode, output_file=output_file))
    listener.start()

    # ---------------------------------------
    # --- Loop while displaying selected collection
    # ---------------------------------------
    print('\nPress ' + Fore.BLUE + '"h"' + Style.RESET_ALL + ' for instructions.\n')

    rate = rospy.Rate(10)  # in hertz.
    try:
        while not rospy.is_shutdown() and not selection['exit']:
            models = {'dataset': dataset, 'args': args, 'graphics': graphics}
            visualizationFunction(models=models, selection=selection, clicked_points=clicked_points)
            rate.sleep()
    except rospy.ROSInterruptException:
        # Raised by rate.sleep() when ROS shuts down while sleeping (e.g. Ctrl+C): a normal exit.
        pass
    finally:
        # Stop listening to the keyboard once the playback loop is over.
        listener.stop()


# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
