#!/usr/bin/env python3

import time
from abc import ABC, abstractmethod
from time import sleep
import math

import rospy
import smach
import numpy as np
import smach_ros
from larcc_demos.srv import ActivateStateOutcome, ActivateStateOutcomeRequest, ActivateStateOutcomeResponse
from gripper_action_server.msg import GripperControlAction, GripperControlGoal, GripperControlResult, GripperControlActionFeedback
from larcc_volumetric.msg import OccupiedPercentage
from moveit_msgs.msg import MoveGroupActionFeedback
from control_msgs.msg import  FollowJointTrajectoryAction, FollowJointTrajectoryActionFeedback
from sensor_msgs.msg import JointState
import actionlib
from moveit_msgs.msg import MoveGroupAction, MoveGroupActionFeedback
from states.utils import send_goal_joint_states, send_trajectory_goal_joint_states, define_movement
from states.service_triggered_state import ServiceTriggeredState
from colorama import Fore, Style


# Define state S3
class S3_GiveObject(smach.State, ServiceTriggeredState):
    """SMACH state that drives the arm through ``passage_fast`` -> ``interact``.

    On ``execute`` the state sends one joint trajectory goal (built by
    ``send_trajectory_goal_joint_states``) and then polls the status reported
    on the trajectory controller feedback topic until it equals
    ``_STATUS_SUCCEEDED``, at which point it returns ``'e9_object_available'``.

    NOTE(review): the wait loop only reacts to a "succeeded" status. If the
    goal is aborted or rejected the state keeps waiting until ROS shuts down;
    handling that would need an additional outcome registered with the state
    machine.
    """

    # Value of ``feedback.status.status`` that ends the wait loop.
    # 3 corresponds to SUCCEEDED in actionlib_msgs/GoalStatus.
    _STATUS_SUCCEEDED = 3

    def __init__(self):
        outcomes = ['e9_object_available']
        input_keys = ['object', 'last_goal_id', 'positions', 'max_joint_values', 'movement_order']
        self.name = self.__class__.__name__
        ServiceTriggeredState.__init__(self, name=self.name, outcomes=outcomes)
        smach.State.__init__(self, outcomes=outcomes, input_keys=input_keys)

        # Read from the parameter server at construction time. ``execute``
        # itself takes its positions from ``userdata.positions`` instead.
        self.positions = rospy.get_param("/demo/positions")
        # Last status seen on the feedback topic; starts at the "succeeded"
        # value and is reset by ``execute`` before a goal is sent.
        self.movement_status = 3

        self.movement_id = 0
        # Latest joint positions from /joint_states; None until the first
        # message has been received.
        self.joint_states = None

        waypoint_vel = 0.5
        max_vel = 0.5

        # One entry per movement: the named positions to traverse and the
        # velocity scaling handed to the trajectory helper.
        self.movement_order = [
            {
                "pos": ["passage_fast", "interact"],
                "max_vel": [max_vel, max_vel],
                "way_point_vel": [waypoint_vel],
                "object": None,
            },
        ]

        # ROS communications
        rospy.Subscriber('/scaled_pos_joint_traj_controller/follow_joint_trajectory/feedback',
                         FollowJointTrajectoryActionFeedback,
                         self.robot_feedback_callback)
        rospy.Subscriber('/joint_states', JointState, self.joint_states_callback)
        self.arm_client = actionlib.SimpleActionClient('/scaled_pos_joint_traj_controller/follow_joint_trajectory',
                                                       FollowJointTrajectoryAction)
        self.gripper_client = actionlib.SimpleActionClient('/gripper_action_server', GripperControlAction)
        # Block construction until both action servers are reachable.
        self.arm_client.wait_for_server()
        self.gripper_client.wait_for_server()

    def joint_states_callback(self, msg):
        """Store the joint positions of the most recent ``/joint_states`` message."""
        self.joint_states = msg.position

    def robot_feedback_callback(self, feedback):
        """Store the goal status code carried by a trajectory feedback message."""
        self.movement_status = feedback.status.status

    def _wait_for_joint_states(self, rate):
        """Block until a joint state has been received or ROS is shutting down.

        Returns the stored joint positions, or None if shutdown was requested
        before any ``/joint_states`` message arrived.
        """
        while self.joint_states is None and not rospy.is_shutdown():
            rospy.logwarn_throttle(5.0, "{}: waiting for /joint_states".format(self.name))
            rate.sleep()
        return self.joint_states

    # TODO implement specific state execute
    def execute(self, userdata):
        """Send the trajectory goal and wait for the controller to report success.

        Reads ``max_joint_values`` and ``positions`` from ``userdata``.
        Returns ``'e9_object_available'`` once the feedback status equals
        ``_STATUS_SUCCEEDED``. Returns None if ROS shuts down first.
        """
        self.movement_id = 0
        max_joints_vel = userdata.max_joint_values
        positions = userdata.positions
        sleep(1)
        rate = rospy.Rate(10)

        # The trajectory must start from the current joint configuration, so
        # make sure at least one /joint_states message has actually arrived.
        initial_pos = self._wait_for_joint_states(rate)
        if initial_pos is None:
            # Only reachable when ROS is shutting down; do not command the arm.
            return None

        movement = self.movement_order[self.movement_id]
        points = [positions[name] for name in movement["pos"]]

        # Reset immediately before sending the goal so that a stale status
        # received during the delay above cannot end the wait loop early.
        self.movement_status = 0
        send_trajectory_goal_joint_states(client=self.arm_client,
                                          initial_point=initial_pos,
                                          points=points,
                                          max_joint_vel=max_joints_vel,
                                          scale_vel=movement["max_vel"],
                                          scale_vel_waypoint=movement["way_point_vel"])

        while not rospy.is_shutdown():
            if self.movement_status == self._STATUS_SUCCEEDED:  # To state 4
                return 'e9_object_available'

            rate.sleep()
