#!/usr/bin/env python3

import time
from abc import ABC, abstractmethod
from time import sleep

import rospy
import smach
import numpy as np
import smach_ros
from larcc_demos.srv import ActivateStateOutcome, ActivateStateOutcomeRequest, ActivateStateOutcomeResponse
from gripper_action_server.msg import GripperControlAction, GripperControlGoal, GripperControlResult, GripperControlActionFeedback
from larcc_volumetric.msg import OccupiedPercentage
from moveit_msgs.msg import MoveGroupActionFeedback
from control_msgs.msg import  FollowJointTrajectoryAction, FollowJointTrajectoryActionFeedback
from sensor_msgs.msg import JointState
import actionlib
from moveit_msgs.msg import MoveGroupAction, MoveGroupActionFeedback
from states.utils import send_goal_joint_states, send_trajectory_goal_joint_states
from states.service_triggered_state import ServiceTriggeredState
from colorama import Fore, Style


# Define state S6
class S6_RecoverObject(smach.State, ServiceTriggeredState):
    """Smach state that recovers an object by replaying a fixed arm/gripper sequence.

    ``userdata.object`` selects one of the step lists in ``self.movement_order``
    (index 0: ball recovery, index 1: block recovery). Each step is either a
    gripper command (``"open"``/``"close"``) or an arm trajectory through the
    named positions found in ``userdata.positions``. A new step is only
    dispatched while ``self.movement_status`` holds the "succeeded" code, which
    the feedback callbacks update.
    """

    # actionlib_msgs/GoalStatus codes used as the "busy"/"ready" flags.
    _STATUS_ACTIVE = 1     # GoalStatus.ACTIVE
    _STATUS_SUCCEEDED = 3  # GoalStatus.SUCCEEDED

    def __init__(self):
        outcomes = ['e10_object_recovered']
        input_keys = ['object', 'last_goal_id', 'positions', 'max_joint_values', 'movement_order']
        self.name = self.__class__.__name__
        ServiceTriggeredState.__init__(self, name=self.name, outcomes=outcomes)
        smach.State.__init__(self, outcomes=outcomes, input_keys=input_keys)

        # Velocity scaling factors forwarded to send_trajectory_goal_joint_states
        # ("max_vel" -> scale_vel, "way_point_vel" -> scale_vel_waypoint).
        waypoint_vel = 0.2
        max_vel = 0.5
        fa = 0.5  # scales max_vel down for the final reach and the lift segments

        self.joint_states = None  # latest /joint_states positions; None until the first message
        self.movement_id = 0      # index of the next step to dispatch
        self.movement_status = self._STATUS_SUCCEEDED  # start out "ready"

        # One step list per recoverable object, indexed by userdata.object.
        # Arm steps name entries of userdata.positions; "open"/"close" are gripper commands.
        self.movement_order = [
            [  # 0: ball
                {"pos": ["ball_recovery_app", "ball_recovery"], "max_vel": [max_vel, fa * max_vel], "way_point_vel": [waypoint_vel]},  # approach and reach the ball
                {"pos": ["close"], "max_vel": max_vel, "way_point_vel": 0},  # grab the ball
                {"pos": ["ball_recovery_app"], "max_vel": [fa * max_vel], "way_point_vel": [0]},  # lift back to the approach pose
            ],
            [  # 1: block
                {"pos": ["block_recovery_app", "block_recovery"], "max_vel": [max_vel, fa * max_vel], "way_point_vel": [waypoint_vel]},  # approach and reach the block
                {"pos": ["close"], "max_vel": max_vel, "way_point_vel": 0},  # grab the block
                {"pos": ["block_recovery_app"], "max_vel": [fa * max_vel], "way_point_vel": [0]},  # lift back to the approach pose
            ],
        ]

        # ROS communications
        rospy.Subscriber('/scaled_pos_joint_traj_controller/follow_joint_trajectory/feedback', FollowJointTrajectoryActionFeedback, self.robot_feedback_callback)
        rospy.Subscriber('/gripper_action_server/feedback', GripperControlActionFeedback, self.gripper_feedback_callback)
        rospy.Subscriber('/joint_states', JointState, self.joint_states_callback)
        self.arm_client = actionlib.SimpleActionClient('/scaled_pos_joint_traj_controller/follow_joint_trajectory',
                                                       FollowJointTrajectoryAction)
        self.gripper_client = actionlib.SimpleActionClient('/gripper_action_server', GripperControlAction)
        # Block construction until both action servers are reachable.
        self.arm_client.wait_for_server()
        self.gripper_client.wait_for_server()

    def joint_states_callback(self, msg):
        """Cache the latest joint positions received on ``/joint_states``."""
        self.joint_states = msg.position

    def gripper_feedback_callback(self, feedback):
        """Map gripper action feedback onto the shared movement status.

        A gripper feedback ``status`` of 1 is translated to the "succeeded"
        code (3); any other value marks the system as busy (1).
        """
        # NOTE(review): assumes gripper status 1 means the command finished — confirm
        # against gripper_action_server.
        if feedback.feedback.status == 1:
            self.movement_status = self._STATUS_SUCCEEDED
        else:
            self.movement_status = self._STATUS_ACTIVE

    def robot_feedback_callback(self, feedback):
        """Store the actionlib goal status reported by the trajectory controller."""
        self.movement_status = feedback.status.status

    def _dispatch_step(self, step, positions, max_joints_vel):
        """Send one step of a movement sequence to the gripper or the arm.

        Returns True when the step was sent, False when it must be retried
        because no ``/joint_states`` message has been received yet.
        """
        first_order = step["pos"][0]
        if first_order in ("open", "close"):
            goal = GripperControlGoal(goal=first_order, speed=255)
            self.gripper_client.send_goal(goal)
            return True

        initial_pos = self.joint_states
        if initial_pos is None:
            # The trajectory starts from the current joint state, so one must be known.
            rospy.logwarn_throttle(5, f"{self.name}: waiting for /joint_states before moving the arm")
            return False

        points = [positions[p] for p in step["pos"]]
        send_trajectory_goal_joint_states(client=self.arm_client,
                                          initial_point=initial_pos,
                                          points=points,
                                          max_joint_vel=max_joints_vel,
                                          scale_vel=step["max_vel"],
                                          scale_vel_waypoint=step["way_point_vel"])
        return True

    def execute(self, userdata):
        """Run the recovery sequence selected by ``userdata.object``.

        Args:
            userdata: smach userdata providing ``object`` (index into
                ``self.movement_order``), ``positions`` (position name ->
                joint values) and ``max_joint_values`` (forwarded as
                ``max_joint_vel`` to the trajectory helper).

        Returns:
            ``'e10_object_recovered'`` once every step has been dispatched and
            the status reports completion; ``None`` if ROS shuts down first.
        """
        max_joints_vel = userdata.max_joint_values
        positions = userdata.positions
        movement_order = self.movement_order[userdata.object]

        # Treat the robot as idle on entry and start at the first step, so a stale
        # "active" status or an interrupted earlier run cannot stall or skip steps.
        self.movement_status = self._STATUS_SUCCEEDED
        self.movement_id = 0

        rate = rospy.Rate(10)
        while not rospy.is_shutdown():
            if self.movement_status == self._STATUS_SUCCEEDED:
                if self.movement_id >= len(movement_order):
                    # Every step was dispatched and the last one reported completion.
                    self.movement_id = 0
                    return 'e10_object_recovered'
                # NOTE(review): the status is not marked busy after a dispatch, so the
                # next step may be sent on the next tick before feedback for this one
                # arrives (unless send_trajectory_goal_joint_states blocks) — confirm.
                if self._dispatch_step(movement_order[self.movement_id], positions, max_joints_vel):
                    self.movement_id += 1

            rate.sleep()