#!/usr/bin/env python3

import time
from abc import ABC, abstractmethod
from time import sleep
import math

import rospy
import smach
import numpy as np
import smach_ros
from larcc_demos.srv import ActivateStateOutcome, ActivateStateOutcomeRequest, ActivateStateOutcomeResponse
from gripper_action_server.msg import GripperControlAction, GripperControlGoal, GripperControlResult, GripperControlActionFeedback
from larcc_volumetric.msg import OccupiedPercentage
from moveit_msgs.msg import MoveGroupActionFeedback
from control_msgs.msg import  FollowJointTrajectoryAction, FollowJointTrajectoryActionFeedback
from sensor_msgs.msg import JointState
import actionlib
from moveit_msgs.msg import MoveGroupAction, MoveGroupActionFeedback
from states.utils import send_goal_joint_states, send_trajectory_goal_joint_states, define_movement
from states.service_triggered_state import ServiceTriggeredState
from colorama import Fore, Style



# Define state S1
class S1_FastObjectManipulation(smach.State, ServiceTriggeredState):
    """Smach state that runs the fast manipulation sequence until an operator is detected.

    The sequence is read from ``userdata.movement_order``, a list of dicts with keys:

    * ``"pos"`` -- either a list of position names looked up in
      ``userdata.positions`` (``"waypoint"`` is an alias for ``"passage_fast"``),
      or a list whose first entry is the gripper command ``"open"``/``"close"``;
    * ``"max_vel"`` / ``"way_point_vel"`` -- velocity scales forwarded to
      ``send_trajectory_goal_joint_states``;
    * ``"object"`` -- value written to ``userdata.object`` when the step is issued.

    Example entry: ``{"pos": ["passage_fast", "ball_pose_2_app"], "max_vel": 1,
    "way_point_vel": 1.2, "object": 0}``.

    The sequence is cycled until ROS shuts down or the volumetric monitor reports
    occupancy above ``occupation_threshold`` while the arm is at least
    ``_INTERRUPT_DISTANCE`` (joint-space Euclidean distance) away from its last
    commanded arm goal; in that case the arm goal is cancelled and the outcome
    ``'e1_operator_detected'`` is returned.
    """

    # actionlib_msgs/GoalStatus codes stored in self.movement_status.
    _STATUS_ACTIVE = 1
    _STATUS_SUCCEEDED = 3
    # Minimum joint-space distance to the last arm goal for a detection to interrupt.
    _INTERRUPT_DISTANCE = 0.3

    def __init__(self):
        """Register outcomes/userdata keys, connect to ROS, and move to the start pose.

        Blocks until both action servers are available, until a first
        /joint_states message has been received, and until the arm has reached
        ``/demo/positions/passage_fast`` and the gripper has opened.
        """
        outcomes = ['e1_operator_detected']
        input_keys = ['object', 'last_goal_id', 'positions', 'max_joint_values', 'movement_order', 'last_position_id']
        output_keys = ['object', 'last_goal_id', 'last_position_id']
        self.name = self.__class__.__name__
        ServiceTriggeredState.__init__(self, name=self.name, outcomes=outcomes)
        smach.State.__init__(self, outcomes=outcomes, input_keys=input_keys, output_keys=output_keys)

        # Local variables
        self.person_detected = False
        # Occupancy above which a person is considered to be in the ROI.
        # NOTE(review): an older comment said "5 percent"; 0.015 is 1.5 % if
        # OccupiedPercentage.percentage is a 0-1 fraction -- confirm the units.
        self.occupation_threshold = 0.015

        # Pose the arm is sent to on construction.
        initial_position = rospy.get_param("/demo/positions/passage_fast")

        # Per-joint velocity limits in the order elbow, shoulder_lift, shoulder_pan, wrist_1..3.
        # NOTE(review): this order must match the /joint_states ordering used as
        # the trajectory start point -- confirm against the robot driver.
        joint_names = [
            "elbow_joint",
            "shoulder_lift_joint",
            "shoulder_pan_joint",
            "wrist_1_joint",
            "wrist_2_joint",
            "wrist_3_joint",
        ]
        max_joints_vel = np.array([
            rospy.get_param(f"/robot_description_planning/joint_limits/{joint}/max_velocity")
            for joint in joint_names
        ])

        self.movement_id = 0
        self.movement_status = self._STATUS_SUCCEEDED
        self.joint_states = None
        self.list_needs_app = ["ball_pose_2", "ball_pose_1", "block_pose_2", "block_pose_1"]

        # ROS communications
        rospy.Subscriber("/occupied_percentage", OccupiedPercentage, self.occupiedPercentageCallback)  # volumetric monitoring
        rospy.Subscriber('/scaled_pos_joint_traj_controller/follow_joint_trajectory/feedback',
                         FollowJointTrajectoryActionFeedback, self.robot_feedback_callback)
        rospy.Subscriber('/gripper_action_server/feedback', GripperControlActionFeedback, self.gripper_feedback_callback)
        rospy.Subscriber('/joint_states', JointState, self.joint_states_callback)
        self.arm_client = actionlib.SimpleActionClient('/scaled_pos_joint_traj_controller/follow_joint_trajectory',
                                                       FollowJointTrajectoryAction)
        self.gripper_client = actionlib.SimpleActionClient('/gripper_action_server', GripperControlAction)
        self.arm_client.wait_for_server()
        self.gripper_client.wait_for_server()

        # The subscriber may not have delivered a message yet. Both the initial
        # trajectory below and execute() start from the current joint state, so
        # block until one is available instead of passing None.
        if self.joint_states is None:
            self.joint_states = rospy.wait_for_message('/joint_states', JointState).position

        send_trajectory_goal_joint_states(client=self.arm_client,
                                          initial_point=self.joint_states,
                                          points=[initial_position],
                                          max_joint_vel=max_joints_vel,
                                          scale_vel=[1],
                                          scale_vel_waypoint=[0])

        self.gripper_client.send_goal(GripperControlGoal(goal="open", speed=100))

        self.arm_client.wait_for_result()
        self.gripper_client.wait_for_result()

    def joint_states_callback(self, msg):
        """Cache the latest joint positions received on /joint_states."""
        self.joint_states = msg.position

    def gripper_feedback_callback(self, feedback):
        """Translate gripper feedback into movement_status.

        A gripper feedback status of 1 is mapped to SUCCEEDED (3); any other
        value is mapped to ACTIVE (1).
        """
        # NOTE(review): gripper status 1 is treated as "finished" -- confirm
        # against the GripperControl action definition.
        if feedback.feedback.status == 1:
            self.movement_status = self._STATUS_SUCCEEDED
        else:
            self.movement_status = self._STATUS_ACTIVE

    def robot_feedback_callback(self, feedback):
        """Mirror the arm trajectory action's goal status into movement_status."""
        self.movement_status = feedback.status.status

    def occupiedPercentageCallback(self, msg):
        """Update person_detected from the volumetric occupancy and log detections."""
        self.person_detected = msg.percentage > self.occupation_threshold

        if self.person_detected:
            print('State ' + self.name + ' - ' + Fore.RED + 'Person detected in ROI!' + Style.RESET_ALL)

    @staticmethod
    def _last_passed_point_index(goal_message, elapsed):
        """Return the index of the last trajectory point already reached at ``elapsed``.

        This is the index just before the first point whose ``time_from_start``
        is >= ``elapsed``. Returns -1 if no point has been reached yet. When
        ``elapsed`` is beyond every point (e.g. the speed-scaled controller
        executed slower than the nominal timing), returns the last index instead
        of raising.
        """
        # Compared as rospy.Time, matching how time_from_start has been compared
        # here before. NOTE(review): this presumably mirrors how the helper that
        # builds goal_message fills time_from_start -- confirm.
        elapsed = rospy.Time(elapsed.secs, elapsed.nsecs)
        goal_times = [p.time_from_start for p in goal_message.trajectory.points]
        first_pending = next((i for i, t in enumerate(goal_times) if t >= elapsed), len(goal_times))
        return first_pending - 1

    def execute(self, userdata):
        """Cycle through userdata.movement_order, starting at userdata.last_goal_id.

        Each step is either a gripper command or an arm trajectory. A new step is
        issued only when movement_status reports SUCCEEDED. The first arm step
        may skip points: it starts from the closest one, capped by
        userdata.last_position_id, to resume an interrupted trajectory.

        Returns 'e1_operator_detected' when a person is detected while the arm is
        at least _INTERRUPT_DISTANCE from its last goal. Before returning, it
        stores the index of the last passed trajectory point in
        userdata.last_position_id and cancels the arm goal. Returns None if ROS
        shuts down first.
        """
        movement_order = userdata.movement_order
        max_joints_vel = userdata.max_joint_values
        positions = userdata.positions

        last_goal = self.joint_states
        is_first_movement = True
        self.movement_status = self._STATUS_SUCCEEDED
        self.movement_id = userdata.last_goal_id
        last_position_id = userdata.last_position_id
        print(f"State {self.name} - {userdata.last_goal_id} - {movement_order[userdata.last_goal_id]}")
        sleep(0.6)
        st = rospy.Time.now()  # time the current arm goal was sent
        goal_message = None  # current arm goal; None while the gripper is the active command
        rate = rospy.Rate(10)
        while not rospy.is_shutdown():
            if self.movement_status == self._STATUS_SUCCEEDED:
                # Debug output about the arm goal that just finished.
                if goal_message is not None:
                    print(f"time elapsed: {rospy.Time.now() - st}")
                    goal_times = [p.time_from_start for p in goal_message.trajectory.points]
                    print(goal_times)
                    print([t > goal_times[0] for t in goal_times])
                    goal_message = None

                movement = movement_order[self.movement_id]
                first_order = movement["pos"][0]
                if first_order in ("open", "close"):
                    goal_message = None  # marks that the gripper, not the arm, is moving
                    self.gripper_client.send_goal(GripperControlGoal(goal=first_order, speed=255))
                else:
                    initial_pos = self.joint_states
                    points = [positions["passage_fast" if p == "waypoint" else p] for p in movement["pos"]]

                    if is_first_movement:
                        # Resume: start from the point closest to the current pose, but
                        # never beyond the last point reached before the interruption.
                        joints_dists = [math.dist(position, initial_pos) for position in points]
                        min_index = joints_dists.index(min(joints_dists))
                        points = points[max(min(last_position_id, min_index), 0):]
                        is_first_movement = False

                    goal_message = send_trajectory_goal_joint_states(client=self.arm_client,
                                                                     initial_point=initial_pos,
                                                                     points=points,
                                                                     max_joint_vel=max_joints_vel,
                                                                     scale_vel=movement["max_vel"],
                                                                     scale_vel_waypoint=movement["way_point_vel"])
                    last_goal = points[-1]
                    st = rospy.Time.now()

                userdata.object = movement["object"]
                userdata.last_goal_id = self.movement_id
                self.movement_id += 1
                if self.movement_id >= len(movement_order):
                    self.movement_id = 0

            # Wall time since the current arm goal was sent (None while the gripper moves).
            last_goal_time = rospy.Time.now() - st if goal_message is not None else None

            # Transition to state S2 only while the arm is still far from its goal.
            if self.person_detected and math.dist(last_goal, self.joint_states) >= self._INTERRUPT_DISTANCE:
                if goal_message is not None:
                    userdata.last_position_id = self._last_passed_point_index(goal_message, last_goal_time)

                self.arm_client.cancel_goal()
                self.arm_client.wait_for_result()
                rate.sleep()
                print(f"State {self.name} - {userdata.last_goal_id} - {movement_order[userdata.last_goal_id]}")
                return 'e1_operator_detected'

            rate.sleep()