#!/usr/bin/env python3
import time
from abc import ABC, abstractmethod
from time import sleep

import rospy
import smach
import numpy as np
import smach_ros
from physical_interaction.msg import PhysicalClassification
from hand_gesture_recognition.msg import HandsClassified
from larcc_demos.srv import ActivateStateOutcome, ActivateStateOutcomeRequest, ActivateStateOutcomeResponse
from gripper_action_server.msg import GripperControlAction, GripperControlGoal, GripperControlResult, GripperControlActionFeedback
from larcc_volumetric.msg import OccupiedPercentage
from moveit_msgs.msg import MoveGroupActionFeedback
from control_msgs.msg import  FollowJointTrajectoryAction, FollowJointTrajectoryActionFeedback
from sensor_msgs.msg import JointState
import actionlib
from moveit_msgs.msg import MoveGroupAction, MoveGroupActionFeedback
from states.utils import send_goal_joint_states, send_trajectory_goal_joint_states
from states.service_triggered_state import ServiceTriggeredState
from colorama import Fore, Style


# Define state S4
class S4_WaitForInteraction(smach.State, ServiceTriggeredState):
    """Smach state that waits for a physical interaction with the robot.

    While active, :meth:`execute` polls the latest label received on
    ``/classification`` at 10 Hz:

    * ``PUSH`` -> returns ``'e6_interaction_push_detected'``.
    * ``PULL`` -> opens the gripper, sends the arm from its current joint
      configuration to the ``"passage_safe"`` position, waits for the arm
      action to finish and returns ``'e5_interaction_pull_detected'``.

    ``'e7_interaction_twist_detected'`` and ``'e2_operator_not_detected'`` are
    declared outcomes but are not returned by this implementation.

    Userdata keys read by :meth:`execute`: ``positions`` (indexed by position
    name, e.g. ``"passage_safe"``) and ``max_joint_values``.
    """

    def __init__(self):
        outcomes = ['e5_interaction_pull_detected', 'e6_interaction_push_detected',
                    'e7_interaction_twist_detected', 'e2_operator_not_detected']
        self.name = self.__class__.__name__
        input_keys = ['object', 'last_goal_id', 'positions', 'max_joint_values', 'movement_order']
        ServiceTriggeredState.__init__(self, name=self.name, outcomes=outcomes)
        smach.State.__init__(self, outcomes=outcomes, input_keys=input_keys)

        # Value of OccupiedPercentage.percentage above which a person is
        # considered present in the monitored volume.
        self.occupation_threshold = 0.005
        self.person_detected = False
        self.physical_classification = "None"
        # Latest joint positions; None until the first /joint_states message
        # arrives. Initialised here so execute() never hits a missing attribute.
        self.joint_states = None

        max_vel = 0.5
        # Motion performed after a PULL: move to the named "passage_safe" position.
        self.movement_order = [
            {"pos": ["passage_safe"], "max_vel": [max_vel], "way_point_vel": [0], "object": None},
        ]

        # ROS communications. Created after all attributes their callbacks write
        # have been initialised, so a callback cannot be overwritten by __init__.
        rospy.Subscriber("/occupied_percentage", OccupiedPercentage, self.occupiedPercentageCallback)
        rospy.Subscriber('/scaled_pos_joint_traj_controller/follow_joint_trajectory/feedback',
                         FollowJointTrajectoryActionFeedback, self.robot_feedback_callback)
        rospy.Subscriber('/gripper_action_server/feedback', GripperControlActionFeedback,
                         self.gripper_feedback_callback)
        rospy.Subscriber('/joint_states', JointState, self.joint_states_callback)
        rospy.Subscriber("/classification", PhysicalClassification, self.physicalClassificationCallback)
        self.arm_client = actionlib.SimpleActionClient(
            '/scaled_pos_joint_traj_controller/follow_joint_trajectory', FollowJointTrajectoryAction)
        self.gripper_client = actionlib.SimpleActionClient('/gripper_action_server', GripperControlAction)
        # Block construction until both action servers are available.
        self.arm_client.wait_for_server()
        self.gripper_client.wait_for_server()

    def joint_states_callback(self, msg):
        """Cache the latest joint positions received on ``/joint_states``."""
        self.joint_states = msg.position

    def physicalClassificationCallback(self, msg):
        """Cache the latest physical-interaction label received on ``/classification``."""
        self.physical_classification = msg.classification

    def gripper_feedback_callback(self, feedback):
        """Map gripper action feedback onto ``movement_status``.

        A gripper feedback status of 1 is recorded as 3; anything else as 1.
        """
        # NOTE(review): 3/1 presumably mirror actionlib GoalStatus SUCCEEDED/ACTIVE
        # so this shares a scale with robot_feedback_callback -- confirm.
        self.movement_status = 3 if feedback.feedback.status == 1 else 1

    def robot_feedback_callback(self, feedback):
        """Record the arm trajectory action's goal status in ``movement_status``."""
        self.movement_status = feedback.status.status

    def occupiedPercentageCallback(self, msg):
        """Update ``person_detected`` from the volumetric occupation value.

        Prints a notice every time a message reports the ROI as free.
        """
        self.person_detected = msg.percentage > self.occupation_threshold

        if not self.person_detected:
            print('State ' + self.name + Fore.BLUE + 'ROI is free' + Style.RESET_ALL)

    def execute(self, userdata):
        """Poll the interaction classifier until a PUSH or PULL is observed.

        Args:
            userdata: smach userdata providing ``positions`` and ``max_joint_values``.

        Returns:
            ``'e6_interaction_push_detected'`` on PUSH;
            ``'e5_interaction_pull_detected'`` after the PULL motion finishes;
            ``None`` if ROS shuts down before either is seen.
        """
        max_joints_vel = userdata.max_joint_values
        positions = userdata.positions
        # Fixed pause before polling starts (reason not visible in this file).
        sleep(0.6)

        rate = rospy.Rate(10)
        while not rospy.is_shutdown():
            # Read the shared label once per iteration so both checks see the
            # same value even if the subscriber callback updates it meanwhile.
            label = self.physical_classification.upper()
            if label == "PUSH":
                return 'e6_interaction_push_detected'
            if label == "PULL":
                self._perform_pull_motion(positions, max_joints_vel)
                return 'e5_interaction_pull_detected'
            rate.sleep()
        # NOTE(review): None is not a declared outcome; smach may reject it on shutdown.
        return None

    def _perform_pull_motion(self, positions, max_joints_vel):
        """Open the gripper, send the arm to the first movement_order target and wait for the arm."""
        # The gripper goal is sent without waiting for its result before the
        # arm trajectory is dispatched.
        self.gripper_client.send_goal(GripperControlGoal(goal="open", speed=100))

        initial_pos = self.joint_states
        if initial_pos is None:
            # No /joint_states sample cached yet: block for one instead of crashing.
            initial_pos = rospy.wait_for_message('/joint_states', JointState).position

        movement = self.movement_order[0]
        send_trajectory_goal_joint_states(
            client=self.arm_client,
            initial_point=initial_pos,
            points=[positions[p] for p in movement["pos"]],
            max_joint_vel=max_joints_vel,
            scale_vel=movement["max_vel"],
            scale_vel_waypoint=movement["way_point_vel"],
        )
        self.arm_client.wait_for_result()
