#!/usr/bin/env python3

import signal
import sys
import rospy
import actionlib
import numpy as np
import math
from moveit_msgs.msg import MoveGroupAction, MoveGroupActionGoal, MoveGroupActionFeedback, MoveGroupActionResult
from moveit_msgs.msg import MoveGroupGoal
from moveit_msgs.msg import Constraints, JointConstraint
from control_msgs.msg import FollowJointTrajectoryGoal, FollowJointTrajectoryAction, JointTolerance
from trajectory_msgs.msg import JointTrajectoryPoint, JointTrajectory

def define_movement(objective, goals, positions, current_point):
    """Collect the joint-space waypoints and velocity scales for one movement.

    Args:
        objective: mapping with ``"object"`` and ``"position"`` keys. Together
            with each goal name they form the lookup key
            ``f"{object}_{goal}_{position}"`` into ``positions``.
        goals: mapping with ``"pos"`` (sequence of goal names) and
            ``"waypoint_vel"`` (sequence indexed in parallel with ``"pos"``).
        positions: mapping from the composed key to a joint configuration.
        current_point: configuration the movement starts from.

    Returns:
        ``(points, velocities)``:

        - ``points`` starts with ``current_point``. It is followed by the
          configuration for every goal whose name contains ``"pose"``.
        - ``velocities`` starts with ``0``. It is followed by the
          ``goals["waypoint_vel"]`` entry at the same index as each kept goal.

        Goals whose name does not contain ``"pose"`` are skipped.
    """
    target_object = objective["object"]
    position = objective["position"]

    points = [current_point]
    velocities = [0]
    for i, goal in enumerate(goals["pos"]):
        # Test the current goal name, not the whole ``goals["pos"]`` sequence.
        if "pose" not in goal:
            continue
        points.append(positions[f"{target_object}_{goal}_{position}"])
        # ``waypoint_vel`` is indexed by the goal's position in ``goals["pos"]``,
        # including the goals skipped above.
        velocities.append(goals["waypoint_vel"][i])

    return points, velocities

def send_goal_joint_states(client, positions, movement, movement_id, vel=1, acel = 1):
    """Send a joint-space MoveIt ``MoveGroup`` goal for one step of a movement.

    The target is resolved in two lookups: ``movement[movement_id]`` gives a
    key, and ``positions[key]`` gives the joint values. Each value is paired,
    in order, with the joint names below. Each pair becomes a
    ``JointConstraint`` with a +/-0.0001 tolerance and weight 1.0. The pairing
    uses ``zip``, so surplus values or names are silently dropped.

    Args:
        client: action client the goal is handed to. Presumably an
            ``actionlib.SimpleActionClient`` for ``MoveGroupAction``; confirm
            with the caller.
        positions: mapping from position key to a sequence of joint values.
        movement: indexable collection of position keys.
        movement_id: index/key selecting the step within ``movement``.
        vel: stored as ``max_velocity_scaling_factor``.
        acel: stored as ``max_acceleration_scaling_factor``.

    Returns:
        None. The goal is sent without waiting for a result.
    """
    target_values = positions[movement[movement_id]]

    # Joint order in which ``target_values`` are interpreted.
    arm_joint_names = (
        "elbow_joint",
        "shoulder_lift_joint",
        "shoulder_pan_joint",
        "wrist_1_joint",
        "wrist_2_joint",
        "wrist_3_joint",
    )

    target_constraints = Constraints()
    for name, value in zip(arm_joint_names, target_values):
        target_constraints.joint_constraints.append(
            JointConstraint(
                joint_name=name,
                position=value,
                tolerance_above=0.0001,
                tolerance_below=0.0001,
                weight=1.0,
            )
        )

    goal = MoveGroupGoal()
    request = goal.request
    request.group_name = 'manipulator'
    request.goal_constraints = [target_constraints]
    request.num_planning_attempts = 10
    request.allowed_planning_time = 10
    request.max_velocity_scaling_factor = vel
    request.max_acceleration_scaling_factor = acel
    request.workspace_parameters.header.frame_id = 'world'
    request.workspace_parameters.header.stamp = rospy.Time.now()

    # Fire-and-forget: the caller is responsible for waiting on the result.
    client.send_goal(goal)

def _per_segment(value, index):
    """Return ``value[index]`` for a sequence/array, or ``value`` itself for a scalar."""
    # Lets callers pass either one factor for every segment (matching the scalar
    # defaults below) or one entry per segment.
    return value if np.ndim(value) == 0 else value[index]


def _segment_duration(start, end, max_joint_vel, scale):
    """Return the time allotted to move from ``start`` to ``end``, never less than 0.6."""
    # Twice the time the joint needing the longest would take at the scaled
    # maximum speed. The 0.6 floor keeps tiny moves from getting near-zero durations.
    return max(abs(2 * np.max(np.abs(start - end) / (max_joint_vel * scale))), 0.6)


def send_trajectory_goal_joint_states(client, initial_point, points, max_joint_vel, scale_vel = 1, scale_vel_waypoint = 0.2):
    """Build a ``FollowJointTrajectory`` goal through the given waypoints and send it.

    The trajectory starts at ``initial_point`` and then visits every entry of
    ``points`` in order. ``points`` itself is not modified.

    Timing:

    - The start point is placed 0.4 s after trajectory start (0.3 s offset
      plus 0.1 s).
    - Every later point adds a segment duration computed from the largest
      per-joint displacement from the previous point (see
      ``_segment_duration``).

    Velocities:

    - The start and final points get zero velocity.
    - Each intermediate waypoint gets a velocity pointing toward the next
      waypoint. It is scaled so its largest joint component has magnitude
      ``scale_vel_waypoint``.
    - If the next waypoint equals the current one, the velocity is zero.

    Args:
        client: action client the goal is sent to. Presumably an
            ``actionlib.SimpleActionClient`` for ``FollowJointTrajectoryAction``;
            confirm with the caller.
        initial_point: joint positions the trajectory starts from, in the
            ``joint_names`` order below.
        points: sequence of joint-position vectors to visit after
            ``initial_point``.
        max_joint_vel: scalar or per-joint maximum joint speed used for timing.
        scale_vel: scalar, or sequence where entry ``k`` scales the speed of
            the move into ``points[k]``.
        scale_vel_waypoint: scalar, or sequence where entry ``k`` is the peak
            joint velocity at ``points[k]`` when that point is intermediate.

    Returns:
        The ``FollowJointTrajectoryGoal`` that was sent.
    """
    # Build a new list instead of inserting into ``points`` so the caller's list
    # is left untouched (repeated calls would otherwise keep prepending).
    waypoints = [initial_point, *points]
    last = len(waypoints) - 1

    joint_names = ["elbow_joint", "shoulder_lift_joint", "shoulder_pan_joint", "wrist_1_joint", "wrist_2_joint",
                   "wrist_3_joint"]

    goal = FollowJointTrajectoryGoal()
    goal.path_tolerance = []
    goal.trajectory.joint_names = joint_names

    # Running time_from_start in seconds; starts with a fixed 0.3 s offset.
    elapsed = 0.3
    for i, point in enumerate(waypoints):
        current = np.array(point)
        # For i == 0 this wraps to the last waypoint, but it is only used when
        # i == 0 is also the last index (no ``points`` given).
        previous = np.array(waypoints[i - 1])

        if 0 < i < last:
            time_span = _segment_duration(previous, current, max_joint_vel, _per_segment(scale_vel, i - 1))
            following = np.array(waypoints[i + 1])
            vel_mid = (following - current) / time_span
            peak = np.max(np.abs(vel_mid))
            if peak > 0:
                velocities = (vel_mid / peak) * _per_segment(scale_vel_waypoint, i - 1)
            else:
                # Next waypoint coincides with this one: avoid 0/0 -> NaN velocities.
                velocities = np.zeros_like(vel_mid, dtype=float)
        elif i == last:
            time_span = _segment_duration(previous, current, max_joint_vel, _per_segment(scale_vel, i - 1))
            velocities = [0] * len(joint_names)
        else:
            time_span = 0.1
            velocities = [0] * len(joint_names)

        point_msg = JointTrajectoryPoint()
        point_msg.positions = point
        point_msg.velocities = velocities
        elapsed += float(time_span)
        # time_from_start is a duration field, so build a Duration, not a Time.
        point_msg.time_from_start = rospy.Duration(elapsed)
        goal.trajectory.points.append(point_msg)

    # NOTE(review): a goal-time tolerance equal to the whole trajectory duration
    # is very loose. It is kept as before; confirm it is intended.
    goal.goal_time_tolerance = rospy.Duration(elapsed)

    goal.trajectory.header.stamp = rospy.Time.now()
    goal.trajectory.header.frame_id = 'world'
    client.send_goal(goal)

    return goal

# def send_arm_gripper_goal(arm_client, gripper_client):
    
    # send_trajectory_goal_joint_states(client = arm_client, 
    #                                   initial_point=initial_pos, 
    #                                   points=points, 
    #                                   max_joint_vel=self.max_joints_vel, 
    #                                   scale_vel=scale_vel, 
    #                                   scale_vel_waypoint=scale_vel_waypoint)