#!/usr/bin/env python3
import random
import tf
import rospy
import numpy as np
from  scipy.spatial.distance import euclidean as distance
from scipy.spatial.transform import Rotation as R
from visualization_msgs.msg import Marker
from geometry_msgs.msg import Point

class DetectedVolume:
    """Bounding volume attached to one skeleton bone (parent joint -> child joint).

    The volume is rebuilt on every call from the live joint poses in
    ``link_names`` and is used to select the points of a point cloud that lie
    inside it. The two torso bones ('LowerBack'->'Spine' and 'Spine'->'Spine1')
    use an oriented rectangular prism; every other bone uses a cylinder.

    Meaning of ``side_lengths`` (same units as the joint positions):
      * rectangular prism: ``[width_x, width_y, extra_depth]`` -- full width
        along the prism's x axis, full width along the left/right (y) axis,
        and how far the prism extends past the child joint.
      * cylinder: ``[radius, extra_length]`` -- maximum distance from the bone
        axis, and how far the cylinder extends past the child joint.

    ``link_names`` maps a joint name to a dict whose ``'pose'`` entry holds the
    joint position in its first three values (assumed to be expressed in
    ``world_frame`` -- confirm with the caller). When visualizing, the parent
    entry must also provide ``'index'`` (the marker id) and a
    ``'joints'[child_name]`` dict that receives the marker.
    """

    def __init__(self, side_lengths, parent_name, child_name, world_frame):
        self.side_lengths = np.array(side_lengths)
        self.parent_name = parent_name
        self.child_name = child_name
        self.world_frame = world_frame
        # Random RGB colour so each volume's marker is easy to tell apart.
        self.marker_color = [random.random() for _ in range(3)]
        # Torso bones get an oriented box; every other bone gets a cylinder.
        if (self.parent_name == 'Spine' and self.child_name == 'Spine1') \
            or (self.parent_name == 'LowerBack' and self.child_name == 'Spine'):
            self.shape = 'rectangular_prism'
        else:
            self.shape = 'cylinder'

    def is_inside(self, pointcloud_world, link_names, visualize=False):
        """Return the points of ``pointcloud_world`` that lie inside this volume.

        Args:
            pointcloud_world (np.array): (N, 3) array of points.
            link_names (dict): real-time joint data (see the class docstring).
            visualize (bool): when True, a Marker for this volume is stored
                under ``link_names[parent]['joints'][child]['marker']``.

        Returns:
            np.array: the (M, 3) subset of ``pointcloud_world`` inside the volume.

        Raises:
            ValueError: if ``self.shape`` is not a known shape.
        """
        parent = link_names[self.parent_name]

        if self.shape == 'rectangular_prism':
            vertices = self.compute_vertices(link_names)
            inside = pointcloud_world[self.is_inside_prism(pointcloud_world, vertices)]
            if visualize:
                parent['joints'][self.child_name].update(
                    {'marker': self.create_marker_rect_prism(parent['index'], vertices)})
            return inside

        if self.shape == 'cylinder':
            start = np.asarray(parent['pose'][:3], dtype=float)
            end = np.asarray(link_names[self.child_name]['pose'][:3], dtype=float)
            inside = pointcloud_world[self._cylinder_mask(pointcloud_world, start, end)]
            if visualize:
                parent['joints'][self.child_name].update(
                    {'marker': self.create_marker_cylinder(parent['index'], start, end)})
            return inside

        raise ValueError(f"Unknown volume shape: {self.shape!r}")

    def _cylinder_mask(self, points, start, end):
        """Boolean mask of the ``points`` inside the cylinder around start->end.

        The cylinder starts at ``start``, runs along the bone direction and
        extends ``side_lengths[1]`` past ``end`` (the same length the marker
        draws); its radius is ``side_lengths[0]``.
        """
        axis = end - start
        bone_length = np.linalg.norm(axis)
        if bone_length == 0:
            # Coincident joints define no axis; with a zero axis vector every
            # test below would pass and the whole cloud would be selected.
            return np.zeros(len(points), dtype=bool)
        axis = axis / bone_length

        offsets = points - start
        along = offsets @ axis                                   # signed distance along the bone
        radial = np.linalg.norm(np.cross(offsets, axis), axis=1)  # distance from the bone axis
        return ((along >= 0)
                & (along <= bone_length + self.side_lengths[1])
                & (radial <= self.side_lengths[0]))

    def compute_vertices(self, link_names):
        """Compute the 8 vertices of the oriented prism around the torso bone.

        The prism's z axis runs from the parent to the child joint. The y axis
        follows the left-minus-right joint vector ('LeftArm'/'RightArm' for
        'Spine', 'LeftUpLeg'/'RightUpLeg' for 'LowerBack'), made perpendicular
        to z so the frame is orthonormal, and x = y x z.

        Args:
            link_names (dict): dictionary with real time pose of every joint

        Returns:
            vertices (np.array): (8, 3) array ``[t1, t2, t3, t4, b1, b2, b3, b4]``
                where b1..b4 is the face centred on the parent joint and
                t1..t4 the same face shifted by the prism depth along z.

        Raises:
            ValueError: if the parent joint has no left/right reference pair.
        """
        parent = np.asarray(link_names[self.parent_name]['pose'][:3], dtype=float)
        child = np.asarray(link_names[self.child_name]['pose'][:3], dtype=float)

        z = child - parent
        bone_length = np.linalg.norm(z)
        z = z / bone_length
        depth = bone_length + self.side_lengths[2]

        if self.parent_name == 'Spine':
            left_name, right_name = 'LeftArm', 'RightArm'
        elif self.parent_name == 'LowerBack':
            left_name, right_name = 'LeftUpLeg', 'RightUpLeg'
        else:
            raise ValueError(f"No left/right reference joints for parent {self.parent_name!r}")
        y = (np.asarray(link_names[left_name]['pose'][:3], dtype=float)
             - np.asarray(link_names[right_name]['pose'][:3], dtype=float))

        x = np.cross(y, z)
        x = x / np.linalg.norm(x)
        # Re-derive y from z and x: this keeps y's sense but removes any
        # component along z, so the box edges are mutually orthogonal, as
        # is_inside_prism assumes. z x x is already unit length.
        y = np.cross(z, x)

        half_x = (self.side_lengths[0] / 2) * x
        half_y = (self.side_lengths[1] / 2) * y
        near_face = np.array([
            parent + half_x + half_y,   # b1
            parent - half_x + half_y,   # b2
            parent - half_x - half_y,   # b3
            parent + half_x - half_y,   # b4
        ])
        far_face = near_face + depth * z  # t1..t4

        return np.vstack([far_face, near_face])

    def is_inside_prism(self, points, prism_vertices):
        """Return the indices of ``points`` that lie inside the prism.

        The edges t1-b1, b2-b1 and b4-b1 are taken as the box axes (orthogonal
        when the vertices come from ``compute_vertices``). A point is inside
        when its offset from the box centre, projected on each edge direction,
        is at most half that edge's length.

        Args:
            points (np.array): (N, 3) points to test.
            prism_vertices (np.array): (8, 3) vertices ordered
                ``[t1, t2, t3, t4, b1, b2, b3, b4]``.

        Returns:
            list[int]: ascending indices of the points inside the prism.
        """
        vertices = np.asarray(prism_vertices, dtype=float)
        t1, b1, b2, b4 = vertices[0], vertices[4], vertices[5], vertices[7]

        offsets = np.asarray(points) - vertices.mean(axis=0)
        inside = np.ones(len(offsets), dtype=bool)
        for edge in (t1 - b1, b2 - b1, b4 - b1):
            edge_length = np.linalg.norm(edge)
            inside &= np.absolute(offsets @ (edge / edge_length)) * 2 <= edge_length

        return np.flatnonzero(inside).tolist()

    def create_marker_rect_prism(self, index, vertices):
        """Create a LINE_STRIP marker tracing the 12 edges of the prism.

        Args:
            index (int): specific index for each joint (used as marker id)
            vertices (np.array): the 8 vertices ``[t1, t2, t3, t4, b1, b2, b3, b4]``
                as returned by ``compute_vertices``

        Returns:
            (Marker): Marker outlining the prism
        """
        marker = Marker()
        marker.header.frame_id = self.world_frame
        marker.id = index
        marker.type = Marker.LINE_STRIP
        marker.action = Marker.MODIFY

        marker.pose.position.x = 0
        marker.pose.position.y = 0
        marker.pose.position.z = 0

        marker.pose.orientation.x = 0
        marker.pose.orientation.y = 0
        marker.pose.orientation.z = 0
        marker.pose.orientation.w = 1

        marker.scale.x = 0.01

        marker.color.r = self.marker_color[0]
        marker.color.g = self.marker_color[1]
        marker.color.b = self.marker_color[2]
        marker.color.a = 1

        t1, t2, t3, t4, b1, b2, b3, b4 = vertices
        # A box has no single continuous path over its edges without repeats.
        # This order retraces three edges so that all 12 edges get drawn and
        # no diagonal is drawn.
        path = (t1, t2, t3, t4, t1, b1, b2, b3, b4, b1, b2, t2, t3, b3, b4, t4)
        for vertex in path:
            marker.points.append(Point(vertex[0], vertex[1], vertex[2]))

        return marker

    @staticmethod
    def _z_to_direction_quaternion(direction):
        """Quaternion (x, y, z, w) of the shortest rotation taking +z onto ``direction``.

        ``direction`` must be a unit 3-vector.
        """
        z_axis = np.array([0.0, 0.0, 1.0])
        rot_axis = np.cross(z_axis, direction)
        sin_angle = np.linalg.norm(rot_axis)
        cos_angle = float(np.dot(z_axis, direction))
        angle = np.arctan2(sin_angle, cos_angle)
        if sin_angle < 1e-12:
            # Parallel (angle 0) or anti-parallel (angle pi): any axis
            # perpendicular to z works.
            rot_axis = np.array([1.0, 0.0, 0.0])
        else:
            rot_axis = rot_axis / sin_angle
        return R.from_rotvec(rot_axis * angle).as_quat()

    def create_marker_cylinder(self, index, start_pose, end_pose):
        """Create a CYLINDER marker matching the volume tested by ``is_inside``.

        The cylinder starts at ``start_pose`` and extends ``side_lengths[1]``
        past ``end_pose``. Its diameter is ``2 * side_lengths[0]``, because the
        containment test uses ``side_lengths[0]`` as the radius.

        Args:
            index (int): specific index for each joint (used as marker id)
            start_pose (np.array): position of the parent joint (start of the cylinder)
            end_pose (np.array): position of the child joint

        Returns:
            (Marker): Marker with the cylinder
        """
        start_pose = np.asarray(start_pose, dtype=float)
        end_pose = np.asarray(end_pose, dtype=float)

        marker = Marker()
        marker.header.frame_id = self.world_frame
        marker.id = index
        marker.type = Marker.CYLINDER
        marker.action = Marker.MODIFY

        bone_length = distance(start_pose, end_pose)
        height = bone_length + self.side_lengths[1]
        if bone_length > 0:
            direction = (end_pose - start_pose) / bone_length
        else:
            # Degenerate bone: fall back to +z to avoid dividing by zero.
            direction = np.array([0.0, 0.0, 1.0])

        # A CYLINDER marker's pose is its centre and its height runs along its
        # local z axis, so place the centre halfway along the extended bone and
        # rotate z onto the bone direction.
        center = start_pose + direction * (height / 2.0)
        marker.pose.position.x = center[0]
        marker.pose.position.y = center[1]
        marker.pose.position.z = center[2]

        quaternion = self._z_to_direction_quaternion(direction)
        marker.pose.orientation.x = quaternion[0]
        marker.pose.orientation.y = quaternion[1]
        marker.pose.orientation.z = quaternion[2]
        marker.pose.orientation.w = quaternion[3]

        marker.scale.x = 2 * self.side_lengths[0]
        marker.scale.y = 2 * self.side_lengths[0]
        marker.scale.z = height

        marker.color.r = self.marker_color[0]
        marker.color.g = self.marker_color[1]
        marker.color.b = self.marker_color[2]
        marker.color.a = 1

        return marker
        


