#!/usr/bin/env python

# stdlib
import sys
import argparse
import time

import numpy as np

# 3rd-party
import rospkg
import rospy
from colorama import Fore, Style
from matplotlib import cm
from interactive_markers.interactive_marker_server import InteractiveMarkerServer
from interactive_markers.menu_handler import MenuHandler
from visualization_msgs.msg import InteractiveMarker, InteractiveMarkerControl, Marker
from urdf_parser_py.urdf import URDF
from urdf_parser_py.urdf import Pose as URDFPose

# local packages
from atom_core.ros_utils import filterLaunchArguments
from atom_core.config_io import loadConfig
from atom_calibration.initial_estimate.sensor import Sensor
from atom_calibration.initial_estimate.additional_tf import Additional_tf
from tf.listener import TransformListener


class InteractiveFirstGuess(object):
    """Set the initial pose estimate of sensors (and additional tfs) with interactive markers.

    The robot description is read from the '/robot_description' parameter and
    the calibration configuration from the file given on the command line.
    One ``Sensor`` is created per non-anchored sensor and one
    ``Additional_tf`` per configured additional transformation. A menu marker
    offers two entries: save the current estimate into a URDF file, and reset
    everything to the initial configuration.
    """

    def __init__(self, args):
        """Store the command line arguments; ROS resources are created by init().

        Args:
            args: dict of parsed command line arguments. The keys read by this
                class are 'calibration_file', 'marker_scale' and 'filename'.
        """
        self.args = args  # command line arguments
        self.config = None  # calibration config
        self.sensors = []  # sensors
        self.additional_tfs = []  # additional_tfs
        self.urdf = None  # robot URDF
        self.server = None  # interactive markers server
        self.menu = None  # menu handler shared by all markers
        self.listener = None  # tf listener handed to sensors and additional_tfs
        self.cm_sensors = None  # one color per configured sensor
        self.cm_additional_tf = None  # one color per configured additional_tf

    def init(self):
        """Load URDF and configuration, then create the menu and all interactive markers.

        Exits the process with status 1 when '/robot_description' is not set
        or when the calibration configuration could not be loaded.
        """
        # The parameter /robot_description must exist
        if not rospy.has_param('/robot_description'):
            rospy.logerr("Parameter '/robot_description' must exist to continue")
            sys.exit(1)

        self.listener = TransformListener()

        # Load the urdf from /robot_description
        self.urdf = URDF.from_parameter_server()

        self.config = loadConfig(self.args['calibration_file'])
        if self.config is None:
            sys.exit(1)  # loadConfig is expected to have reported why.

        # The configuration does not change after loading, so evaluate once.
        has_additional_tfs = self.checkAdditionalTfs()

        print('Number of sensors: ' + str(len(self.config['sensors'])))
        if has_additional_tfs:
            print('Number of additional_tfs: ' + str(len(self.config['additional_tfs'])))

        # Init interaction
        self.server = InteractiveMarkerServer('set_initial_estimate')
        self.menu = MenuHandler()

        self.createInteractiveMarker(self.config['world_link'])
        self._populateMenu(has_additional_tfs)
        self.menu.reApply(self.server)

        # Waiting for the loading of the bagfile
        time.sleep(2)

        self._createSensors()
        if has_additional_tfs:
            self._createAdditionalTfs()

        self.server.applyChanges()

    def _populateMenu(self, has_additional_tfs):
        """Insert the save and reset entries, worded according to what is being calibrated."""
        if has_additional_tfs:
            # sensors and additional tfs to be calibrated
            save_title = "Save sensors and additional transformations configuration"
            reset_title = "Reset all sensors and additional transformations to initial configuration"
        else:
            # only sensors to be calibrated
            save_title = "Save sensors configuration"
            reset_title = "Reset all sensors to initial configuration"

        self.menu.insert(save_title, callback=self.onSaveInitialEstimate)
        self.menu.insert(reset_title, callback=self.onResetAll)

    def _createSensors(self):
        """Create one Sensor (with its interactive marker) per non-anchored sensor in the config."""
        sensors_config = self.config['sensors']

        # Create a colormap so that we have one color per sensor
        self.cm_sensors = cm.Pastel2(np.linspace(0, 1, len(sensors_config)))

        sensor_idx = 0
        for name, sensor in sensors_config.items():
            print(Fore.BLUE + '\nSensor name is ' + name + Style.RESET_ALL)

            if self.config['anchored_sensor'] == name:
                print(Fore.YELLOW + 'Warning: sensor ' + name + ' is anchored. Will not create corresponding interactive marker.' + Style.RESET_ALL)
                continue

            params = {
                "frame_world": self.config['world_link'],
                "frame_opt_parent": sensor['parent_link'],
                "frame_opt_child": sensor['child_link'],
                "frame_sensor": sensor['link'],
                "sensor_topic": sensor['topic_name'],
                "sensor_modality": sensor['modality'],
                "sensor_color": tuple(self.cm_sensors[sensor_idx, :]),
                "marker_scale": self.args['marker_scale'],
                "listener": self.listener}

            # Append to the list of sensors
            self.sensors.append(Sensor(name, self.server, self.menu, **params))
            sensor_idx += 1
            print('... done')

    def _createAdditionalTfs(self):
        """Create one Additional_tf (with its interactive marker) per additional_tfs config entry."""
        additional_tfs_config = self.config['additional_tfs']

        # Create a colormap so that we have one color per additional_tfs
        self.cm_additional_tf = cm.Pastel2(np.linspace(0, 1, len(additional_tfs_config)))

        for additional_tf_idx, (name, additional_tf) in enumerate(additional_tfs_config.items()):
            print(Fore.BLUE + '\nAdditional_tfs name is ' + name + Style.RESET_ALL)
            params = {
                "frame_world": self.config['world_link'],
                "frame_opt_parent": additional_tf['parent_link'],
                "frame_opt_child": additional_tf['child_link'],
                "frame_additional_tf": additional_tf['child_link'],
                "additional_tf_color": tuple(self.cm_additional_tf[additional_tf_idx, :]),
                "marker_scale": self.args['marker_scale'],
                "listener": self.listener}

            # Append to the list of additional_tfs
            self.additional_tfs.append(Additional_tf(name, self.server, self.menu, **params))
            print('... done')

    def _updateJointOrigins(self, items):
        """Copy the current transform of each item into the origin of its matching URDF joint.

        Args:
            items: iterable of objects exposing 'opt_parent_link',
                'opt_child_link' and 'optT' (with getTranslation() and
                getEulerAngles()), i.e. the sensors or the additional tfs.
        """
        for item in items:
            # find corresponding joint for this item
            for joint in self.urdf.joints:
                if item.opt_child_link != joint.child or item.opt_parent_link != joint.parent:
                    continue

                # if origin is None create a new URDFPose.
                # See https://github.com/lardemua/atom/issues/559
                if joint.origin is None:
                    joint.origin = URDFPose()

                joint.origin.xyz = list(item.optT.getTranslation())
                joint.origin.rpy = list(item.optT.getEulerAngles())

    def onSaveInitialEstimate(self, feedback):
        """Write the URDF, updated with the current estimates, to the configured output file.

        Args:
            feedback: interactive marker feedback that triggered the call (unused).
        """
        print('onSaveInitialEstimate')

        self._updateJointOrigins(self.sensors)
        # additional_tfs is only populated when the config defines additional tfs.
        self._updateJointOrigins(self.additional_tfs)

        # Serialize before opening the file so that a serialization error does
        # not leave a truncated output file behind.
        xml_string = self.urdf.to_xml_string()

        outfile = self.args['filename']
        print("Writing first guess urdf to '{}'".format(outfile))
        with open(outfile, 'w') as out:
            out.write(xml_string)

    def onResetAll(self, feedback):
        """Reset every sensor and additional tf to its initial pose.

        Args:
            feedback: interactive marker feedback that triggered the call (unused).
        """
        for item in self.sensors + self.additional_tfs:
            item.resetToInitalPose()

    def createInteractiveMarker(self, world_link):
        """Create the 'menu' interactive marker and attach the menu handler to it.

        Args:
            world_link: frame id in which the menu marker is expressed.
        """
        marker = InteractiveMarker()
        marker.header.frame_id = world_link
        marker.pose.position.x = 1
        marker.pose.position.y = 0
        marker.pose.position.z = 1
        marker.pose.orientation.x = 0
        marker.pose.orientation.y = 0
        marker.pose.orientation.z = 0
        marker.pose.orientation.w = 1
        marker.scale = 0.2

        marker.name = 'menu'
        marker.description = 'menu'

        # insert a semi-transparent green sphere, freely movable in 3D
        marker_sphere = Marker()
        marker_sphere.type = Marker.SPHERE
        marker_sphere.scale.x = marker.scale * 0.7
        marker_sphere.scale.y = marker.scale * 0.7
        marker_sphere.scale.z = marker.scale * 0.7
        marker_sphere.color.r = 0
        marker_sphere.color.g = 1
        marker_sphere.color.b = 0
        marker_sphere.color.a = 0.2

        control = InteractiveMarkerControl()
        control.always_visible = True
        control.interaction_mode = InteractiveMarkerControl.MOVE_3D
        control.markers.append(marker_sphere)
        marker.controls.append(control)

        # one fixed-orientation MOVE_AXIS control per axis: (name, (x, y, z) of its orientation)
        axis_controls = (("move_x", (1, 0, 0)),
                         ("move_z", (0, 1, 0)),
                         ("move_y", (0, 0, 1)))
        for control_name, (quat_x, quat_y, quat_z) in axis_controls:
            control = InteractiveMarkerControl()
            control.orientation.w = 1
            control.orientation.x = quat_x
            control.orientation.y = quat_y
            control.orientation.z = quat_z
            control.name = control_name
            control.interaction_mode = InteractiveMarkerControl.MOVE_AXIS
            control.orientation_mode = InteractiveMarkerControl.FIXED
            marker.controls.append(control)

        # NOTE(review): the save handler is registered as this marker's feedback
        # callback, so presumably any feedback on the marker (not only the menu
        # entry) triggers a save -- confirm this is intended.
        self.server.insert(marker, self.onSaveInitialEstimate)

        self.menu.apply(self.server, marker.name)

    def checkAdditionalTfs(self):
        """Return True when the configuration defines additional transformations.

        A configuration without the 'additional_tfs' key is treated the same
        as one where the key is set to None.
        """
        try:
            return self.config['additional_tfs'] is not None
        except KeyError:
            return False

if __name__ == "__main__":
    # Parse command line arguments
    ap = argparse.ArgumentParser(description='Create first guess. Sets up rviz interactive markers that allow the '
                                             'user to define the position and orientation of each of the sensors '
                                             'listed in the config.yml.')
    # --filename is required, so it carries no default: a default could never be used.
    # The path is opened exactly as given when the estimate is saved.
    ap.add_argument("-f", "--filename", type=str, required=True,
                    help="Full path and name of the first guess xacro file.")
    ap.add_argument("-s", "--marker_scale", type=float, default=0.6, help='Scale of the interactive markers.')
    ap.add_argument("-c", "--calibration_file", type=str, required=True, help='full path to calibration file.')

    # filterLaunchArguments pre-processes sys.argv before argparse sees it
    # (presumably dropping the extra arguments added by roslaunch -- confirm in atom_core).
    args = vars(ap.parse_args(args=filterLaunchArguments(sys.argv)))

    # Initialize ROS stuff
    rospy.init_node("set_initial_estimate")

    # Launch the application !!
    first_guess = InteractiveFirstGuess(args)
    first_guess.init()
    rospy.spin()
