#!/usr/bin/env python3

import argparse
import math
import os
import shutil
import sys
import yaml
from yaml.loader import SafeLoader
from d_volume_class import DetectedVolume
from functools import partial
import numpy as np
import time
import collada
import pandas as pd
import json
from scipy.spatial.transform import Rotation as R
from scipy.spatial import Delaunay
from skimage.measure import block_reduce
import rospy
import rospkg
import tf
from gazebo_msgs.msg import LinkStates
from sensor_msgs.msg import PointCloud2
from sensor_msgs import point_cloud2 as pointc2                                  
from geometry_msgs.msg import Point
from rosgraph_msgs.msg import Clock
from visualization_msgs.msg import Marker, MarkerArray
import message_filters


class DotDict(dict):
    """
    Dictionary subclass exposing keys as attributes, recursively.

    Attribute reads on nested dicts are wrapped in DotDict so chained
    access (cfg.a.b.c) works. A missing *attribute* reads as None, while
    a missing *item* access (d['x']) auto-creates an empty DotDict.
    """

    def __getattr__(self, name):
        # Re-wrap plain dict values so nested dot access keeps working.
        item = self.get(name)
        return DotDict(item) if isinstance(item, dict) else item

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

    def __missing__(self, key):
        # Item access to an absent key stores and returns an empty DotDict.
        fresh = DotDict()
        self[key] = fresh
        return fresh



def link_states_callback(msg, config):
    """Store the latest gazebo link states message in shared state.

    Args:
        msg (LinkStates): message with gazebo link states
        config (dict): mutable variable to store the message
    """
    config['link_state'] = msg
    

def clock_callback(msg, config):
    """Store the current simulated time, in whole seconds, in shared state.

    Args:
        msg (Clock): message carrying the simulated clock
        config (dict): mutable variable to store the seconds value
    """
    config['clock_secs'] = msg.clock.secs


def pointcloud_filter_callback(pc1, pc2, pc3, pc4, config):
    """
    Stash the four (approximately synchronized) point clouds for processing.

    Args:
        pc1 (PointCloud2): Point cloud message from camera 1.
        pc2 (PointCloud2): Point cloud message from lidar 1.
        pc3 (PointCloud2): Point cloud message from lidar 2.
        pc4 (PointCloud2): Point cloud message from lidar 3.
        config (dict): Shared state dictionary updated in place.

    Side Effects:
        - Stores each cloud in config['pointcloud_msgs_to_process'], keyed
          by its header frame id.
        - Sets config['pointcloud_msgs_timestamp'] to the newest header
          stamp among the four clouds.
    """
    config['pointcloud_msgs_timestamp'] = pc1.header.stamp
    for cloud in (pc1, pc2, pc3, pc4):
        config['pointcloud_msgs_to_process'][cloud.header.frame_id] = cloud
        # Keep the running maximum of the header stamps.
        if cloud.header.stamp > config['pointcloud_msgs_timestamp']:
            config['pointcloud_msgs_timestamp'] = cloud.header.stamp


# def voxel_callback(msg, config):
        # closest_interval = float('inf')
        # print(voxel_cache)
        # for msg in voxel_cache:
        #     if msg.header.stamp.to_sec() > config['pointcloud_msgs_timestamp'] and \
        #     msg.header.stamp.to_sec() - config['pointcloud_msgs_timestamp'] < closest_interval:
        #         closest_interval = msg.header.stamp.to_sec() - config['pointcloud_msgs_timestamp']
        #         occupied_voxels = msg
        
        # if closest_interval == float('inf'):
        #     continue


        # if type(occupied_voxels)==type(Marker()):
        #     config['occupied_voxel_points'] = np.array([[point.x, point.y, point.z] for point in occupied_voxels.points]) 
        # else:



    # return


def iterate_children(node, link_names):
    """Recursively register a collada node and all of its descendants.

    Args:
        node (collada.Scene.SceneNode): root of the subtree to traverse
        link_names (dict): mapping of node id -> {'joints': {...}} filled
            in place; each child id is registered under its parent's joints
    """
    # Every visited node gets its own entry keyed by id.
    link_names[node.id] = {'joints': {}}

    # Nodes without a 'children' attribute (or with none) end the recursion.
    for child in getattr(node, 'children', None) or []:
        link_names[node.id]['joints'][child.id] = {'d_volume': None}
        iterate_children(child, link_names)


def define_d_volumes(link_names, json_file_path, world_frame):
    """
    Attach DetectedVolume objects to the joints listed in a JSON file.

    Args:
        link_names (dict): Dictionary of link names and their joint info.
        json_file_path (str): Path to the JSON file describing each volume.
        world_frame (str): Reference frame passed to every DetectedVolume.

    Returns:
        None

    Side Effects:
        - Sets the 'd_volume' entry of each referenced joint in link_names.
    """
    # Load the rectangular volume description for every joint.
    with open(json_file_path, 'r') as f:
        rect_d_volume_info = json.load(f)

    # Joint keys are encoded as '<parent>_<child>_<suffix>'.
    for joint, info in rect_d_volume_info.items():
        parent, child, _ = joint.split('_')
        volume = DetectedVolume(info['side_lengths'], parent, child, world_frame)
        link_names[parent]['joints'][child]['d_volume'] = volume


def define_link_names(link_names, collada_path, d_volume_path, world_frame):
    """Populate link_names from the actor collada model and the volume file.

    Args:
        link_names (dict): output dictionary filled in place
        collada_path (str): path to the actor collada file
        d_volume_path (str): path to the JSON file with joint volumes
        world_frame (str): reference frame for the detected volumes
    """
    actor_model = collada.Collada(collada_path)
    # The actor skeleton root is the first child of the first scene node.
    skeleton_root = actor_model.scene.nodes[0].children[0]
    iterate_children(skeleton_root, link_names)
    define_d_volumes(link_names, d_volume_path, world_frame)


def filter_link_state(link_states, link_names):
    """
    Record the gazebo index and current pose of every tracked link.

    Args:
        link_states (LinkStates): message listing every gazebo link.
        link_names (dict): link name -> info dict, updated in place with
            'index' and 'pose' (a flat 7-tuple: position xyz + quaternion).

    Returns:
        None. Returns early when any tracked link is absent from the message.
    """
    # Locate each tracked link; gazebo prefixes actor links with 'actor::'.
    try:
        for name in link_names:
            link_names[name]['index'] = link_states.name.index('actor::' + name)
    except ValueError:
        # At least one link is missing from this message — skip the update.
        return None

    # Store position and orientation for each link as one flat tuple.
    for name, info in link_names.items():
        pose = link_states.pose[info['index']]
        info['pose'] = (pose.position.x,
                        pose.position.y,
                        pose.position.z,
                        pose.orientation.x,
                        pose.orientation.y,
                        pose.orientation.z,
                        pose.orientation.w)
    return None


def create_marker_pc2(vertices, frame_id, color, id=0, namespace=None):
    """Build a SPHERE_LIST marker with one small sphere per vertex.

    Args:
        vertices (np.array): array of (x, y, z) vertex coordinates
        frame_id (str): frame id the marker is expressed in
        color (sequence): RGB components in [0, 1]
        id (int, optional): marker id. Defaults to 0.
        namespace (str, optional): marker namespace. Defaults to None.

    Returns:
        (Marker): sphere-list marker containing every vertex
    """
    marker = Marker()
    marker.header.frame_id = frame_id
    marker.id = id
    if namespace:
        marker.ns = namespace
    marker.type = Marker.SPHERE_LIST
    marker.action = Marker.MODIFY

    # Identity pose: the vertices are already expressed in frame_id.
    marker.pose.position.x = marker.pose.position.y = marker.pose.position.z = 0
    marker.pose.orientation.x = marker.pose.orientation.y = marker.pose.orientation.z = 0
    marker.pose.orientation.w = 1

    # Small fixed-size spheres.
    marker.scale.x = marker.scale.y = marker.scale.z = 0.03

    marker.color.r, marker.color.g, marker.color.b = color[0], color[1], color[2]
    marker.color.a = 1

    # One geometry point per vertex.
    marker.points = [Point(v[0], v[1], v[2]) for v in vertices]

    return marker


def create_marker_voxel(points, frame_id, color, resolution, id=0, namespace=None):
    """Build a CUBE_LIST marker with one translucent cube per voxel centre.

    Args:
        points (np.array): array of (x, y, z) voxel-centre coordinates
        frame_id (str): frame id the marker is expressed in
        color (sequence): RGB components in [0, 1]
        resolution (float): edge length of each voxel cube
        id (int, optional): marker id. Defaults to 0.
        namespace (str, optional): marker namespace. Defaults to None.

    Returns:
        (Marker): cube-list marker containing every voxel
    """
    marker = Marker()
    marker.header.frame_id = frame_id
    marker.id = id
    if namespace:
        marker.ns = namespace
    marker.type = Marker.CUBE_LIST
    marker.action = Marker.MODIFY

    # Identity pose: the cube centres are already expressed in frame_id.
    marker.pose.position.x = marker.pose.position.y = marker.pose.position.z = 0
    marker.pose.orientation.x = marker.pose.orientation.y = marker.pose.orientation.z = 0
    marker.pose.orientation.w = 1

    # Each cube spans exactly one voxel.
    marker.scale.x = marker.scale.y = marker.scale.z = resolution

    marker.color.r, marker.color.g, marker.color.b = color[0], color[1], color[2]
    marker.color.a = 0.3  # translucent so overlapping voxels stay visible

    # One cube centre per voxel position.
    marker.points = [Point(p[0], p[1], p[2]) for p in points]

    return marker


def transform(world_frame, sensor_frame, listener, points):
    """Transform an Nx3 pointcloud from the sensor frame to the world frame.

    Args:
        world_frame (str): target (world) frame in ROS
        sensor_frame (str): source sensor frame in ROS
        listener (tf.TransformListener): listener of ROS tfs
        points (np.array): Nx3 pointcloud in the sensor frame

    Returns:
        np.array: 3xN pointcloud expressed in the world frame, or None if
        the transform is not (yet) available.
    """
    # Retrieve the latest transform between world and sensor frame.
    # ExtrapolationException was previously uncaught and crashed the node
    # whenever the tf buffer had no data bracketing the requested time.
    try:
        (trans_sensor2world, quaternion_sensor2world) = listener.lookupTransform(world_frame, sensor_frame, rospy.Time(0))
    except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
        # Transform unavailable: signal the caller with None.
        return None

    # Build the rotation matrix and translation column vector.
    rot_sensor2world = R.from_quat(quaternion_sensor2world)
    R_matrix_sensor2world = rot_sensor2world.as_matrix()
    T_vector_sensor2world = np.array([[trans_sensor2world[0]],
                                      [trans_sensor2world[1]],
                                      [trans_sensor2world[2]]])

    # Apply R * p + T to every point (points transposed to 3xN).
    transformed_points = np.matmul(R_matrix_sensor2world, np.transpose(points)) + T_vector_sensor2world

    return transformed_points


def in_hull(points, hull):
    """
    Return the subset of `points` that lies inside `hull`.

    `points` is an NxK array of N points in K dimensions. `hull` is either
    a scipy.spatial.Delaunay object or an MxK array of points from which a
    Delaunay triangulation is computed.
    """
    triangulation = hull if isinstance(hull, Delaunay) else Delaunay(hull)
    # find_simplex returns -1 for points outside the triangulation.
    inside_mask = triangulation.find_simplex(points) >= 0
    return points[inside_mask]


def verify_points_in_voxels(points, voxel_positions, voxel_resolution, threshold,
                            visualize=False, metrics_pub=None, gt_pub=None, frame_id='world'):
    """
    Verifies the presence of points within voxels and calculates precision, recall, and F1 score.

    Args:
        points (numpy.ndarray): Array of ground-truth point cloud coordinates (Nx3).
        voxel_positions (numpy.ndarray): Array of occupied voxel centre positions.
        voxel_resolution (float): Edge length of the voxels.
        threshold (float): Distance threshold for determining false negatives, due to float comparison.
        visualize (bool, optional): Flag to enable visualization. Defaults to False.
        metrics_pub (Publisher, optional): Publisher for metrics visualization. Defaults to None.
        gt_pub (Publisher, optional): Publisher for ground truth visualization. Defaults to None.
        frame_id (str, optional): Frame ID for visualization. Defaults to 'world'.

    Returns:
        tuple: (precision, recall, f1_score). Each metric is 0 when its
        denominator is empty instead of raising ZeroDivisionError.
    """

    # Create a convex hull from the ground-truth points
    convex_hull = Delaunay(points)

    # Only occupied voxels inside the hull are scored (rounded to match voxel precision)
    voxels_in_hull = in_hull(np.round(voxel_positions, 2), convex_hull).tolist()

    points = points.tolist()
    # Quantize every ground-truth point to its containing voxel centre
    pointcloud_voxel_coordinates = []
    for point in points:
        voxel_x = math.floor(point[0] / voxel_resolution) * voxel_resolution + voxel_resolution / 2
        voxel_y = math.floor(point[1] / voxel_resolution) * voxel_resolution + voxel_resolution / 2
        voxel_z = math.floor(point[2] / voxel_resolution) * voxel_resolution + voxel_resolution / 2
        voxel_coordinate = [voxel_x, voxel_y, voxel_z]
        if voxel_coordinate not in pointcloud_voxel_coordinates:
            pointcloud_voxel_coordinates.append(voxel_coordinate)

    true_positives = []

    # A retrieved voxel is a true positive when at least one point falls inside it
    for voxel in voxels_in_hull:
        for point in points:
            if voxel[0] - voxel_resolution/2 < point[0] < voxel[0] + voxel_resolution/2 and \
            voxel[1] - voxel_resolution/2 < point[1] < voxel[1] + voxel_resolution/2 and \
            voxel[2] - voxel_resolution/2 < point[2] < voxel[2] + voxel_resolution/2:
                true_positives.append(voxel)
                break

    false_positives = [voxel for voxel in voxels_in_hull if voxel not in true_positives]
    # Ground-truth voxels farther than `threshold` from every true positive were missed
    false_negatives = [voxel for voxel in pointcloud_voxel_coordinates if all(math.sqrt(sum((a - b) ** 2 for a, b in zip(voxel, tp))) > threshold for tp in true_positives)]

    if visualize:

        # Visualize metrics (green = TP, blue = FP, red = FN)
        metrics_marker_array = MarkerArray()
        metrics_marker_array.markers.append(create_marker_voxel(true_positives, frame_id, [0, 1.0, 0], voxel_resolution, 0, 'true_positives'))
        metrics_marker_array.markers.append(create_marker_voxel(false_positives, frame_id, [0, 0, 1.0], voxel_resolution, 1, 'false_positives'))
        metrics_marker_array.markers.append(create_marker_voxel(false_negatives, frame_id, [1.0, 0, 0], voxel_resolution, 2, 'false_negatives'))
        metrics_pub.publish(metrics_marker_array)

        # Visualize ground truth
        gt_marker_array = MarkerArray()
        gt_marker_array.markers.append(create_marker_pc2(voxels_in_hull, frame_id, [0.99, 0.91, 0.08], 1, 'retrieved'))
        gt_pub.publish(gt_marker_array)

    # Precision: fraction of retrieved voxels that actually contain a point
    if len(voxels_in_hull) == 0:
        precision = 0
    else:
        precision = len(true_positives) / len(voxels_in_hull)

    # Recall: fraction of occupied ground-truth voxels that were retrieved.
    # Guard the division: with no true positives and no false negatives the
    # previous expression raised ZeroDivisionError.
    recall_denominator = len(true_positives) + len(false_negatives)
    if recall_denominator == 0:
        recall = 0
    else:
        recall = len(true_positives) / recall_denominator

    # F1: harmonic mean of precision and recall (0 when both are 0)
    if precision == 0 and recall == 0:
        f1_score = 0
    else:
        f1_score = (2 * precision * recall) / (precision + recall)

    return precision, recall, f1_score



def main():
    """Entry point: replay an experiment, score the volumetric map against
    the ground-truth joint volumes every cycle, and save the metrics to CSV."""

    # Parse CLI arguments; ROS remapping args (starting with '__') are stripped.
    parser = argparse.ArgumentParser(description='Retrieve metrics for volumetric detection.')
    parser.add_argument('-en', '--experiment_name', type=str, required=True)
    parser.add_argument('-ow', '--overwrite', action='store_true')
    arglist = [x for x in sys.argv[1:] if not x.startswith('__')]
    args = vars(parser.parse_args(args=arglist))

    # Define rospack
    rospack = rospkg.RosPack()

    # Results folder for this experiment.
    experiment_path = rospack.get_path('larcc_volumetric') + f'/results/{args["experiment_name"]}'

    if os.path.exists(experiment_path):
        if args['overwrite']:
            shutil.rmtree(experiment_path)
        else:
            print(f'{experiment_path} already exists. ')
            raise Exception('Experiment name already exists. If you want to overwrite, use flag -ow')

    # create folder to the results.
    os.makedirs(experiment_path)

    # Load the experiment configuration (dot-accessible via DotDict).
    cfg_path = rospack.get_path('larcc_volumetric') + f'/config/{args["experiment_name"]}.yaml'
    with open(cfg_path) as f:
        cfg = DotDict(yaml.load(f, Loader=SafeLoader))

    # Shared mutable state filled in by the ROS callbacks.
    link_names = dict()
    config = dict()
    config['link_state'] = None
    config['occupied_voxel_points'] = None
    config['empty_voxel_points'] = None
    config['clock_secs'] = None
    config['pointcloud_msgs_to_process'] = dict()
    config['pointcloud_msgs_timestamp'] = None
    full_inside_points = []

    # Use cfg class for initial params
    collada_path = cfg.collada_path
    d_volume_path = rospack.get_path('larcc_volumetric') + f'/config/joint_volumes/{cfg.joint_volumes}.json'
    world_frame = cfg.world_frame
    visualize = cfg.visualize
    resolution = cfg.resolution

    # Each volumetric algorithm publishes its occupied voxels on a
    # different topic and with a different message type.
    if cfg.volumetric_algorithm == 'octomap':
        occupied_voxel_topic = '/octomap_point_cloud_centers'
        occupied_voxel_sub = message_filters.Subscriber(occupied_voxel_topic, PointCloud2)
    elif cfg.volumetric_algorithm == 'voxblox':
        occupied_voxel_topic = '/voxblox_node/surface_pointcloud'
        occupied_voxel_sub = message_filters.Subscriber(occupied_voxel_topic, PointCloud2)
    elif cfg.volumetric_algorithm == 'skimap':
        occupied_voxel_topic = '/skimap_live_pc/live_map'
        occupied_voxel_sub = message_filters.Subscriber(occupied_voxel_topic, Marker)
    else:
        raise Exception(f'Unknown volumetric algorithm {cfg.volumetric_algorithm}')

    # Keep the last few voxel messages so the loop can pick the first one
    # published after the pointcloud batch timestamp.
    occupied_voxel_cache = message_filters.Cache(occupied_voxel_sub, 5)

    # Define the link names
    define_link_names(link_names, collada_path, d_volume_path, world_frame)

    # Per-iteration metrics accumulated here and averaged at the end.
    df = pd.DataFrame(columns=["Time", "Precision", "Recall", "F1 Score"])

    # Initialize the node
    rospy.init_node('link_state_filter')

    # If visualize, create publishers for the debug marker arrays
    if visualize:
        marker_array_pub = rospy.Publisher('visualization_marker_array', MarkerArray, queue_size=10)
        marker_array = MarkerArray()
        marker_inside_pub = rospy.Publisher('marker_inside', Marker, queue_size=10)
    metrics_marker_array_pub = rospy.Publisher('metrics_marker_array', MarkerArray, queue_size=10)
    gt_marker_array_pub = rospy.Publisher('gt_marker_array', MarkerArray, queue_size=10)

    # Subscribe to the link states, clock and the point clouds; the four
    # clouds are approximately time-synchronized into one callback.
    link_states_callback_partial = partial(link_states_callback, config=config)
    rospy.Subscriber('/gazebo/link_states', LinkStates, link_states_callback_partial)
    clock_callback_partial = partial(clock_callback, config=config)
    rospy.Subscriber('/clock', Clock, clock_callback_partial)
    pointcloud_filter_callback_partial = partial(pointcloud_filter_callback, config=config)
    pc1_sub = message_filters.Subscriber('/camera_1/depth/points', PointCloud2)
    pc2_sub = message_filters.Subscriber('/lidar_1/velodyne_points', PointCloud2)
    pc3_sub = message_filters.Subscriber('/lidar_2/velodyne_points', PointCloud2)
    pc4_sub = message_filters.Subscriber('/lidar_3/velodyne_points', PointCloud2)
    ts = message_filters.ApproximateTimeSynchronizer([pc1_sub, pc2_sub, pc3_sub, pc4_sub],
                                                    10, 0.1, allow_headerless=False)
    ts.registerCallback(pointcloud_filter_callback_partial)

    # Defining rate
    rate = rospy.Rate(30)
    listener = tf.TransformListener()

    # Main loop: score each batch of synchronized pointclouds against the voxel map.
    while not rospy.is_shutdown():
        # Wait until both link states and pointclouds have arrived.
        # Sleep before continuing: the previous bare `continue` busy-waited
        # at 100% CPU until the first messages arrived.
        if config['link_state'] is None or not config['pointcloud_msgs_to_process']:
            rate.sleep()
            continue

        # Stop when the bag playback reaches its configured end time.
        if config['clock_secs'] == cfg.bag_end:
            break

        # First voxel map published after the pointcloud batch timestamp.
        occupied_voxels = occupied_voxel_cache.getElemAfterTime(config['pointcloud_msgs_timestamp'])

        if occupied_voxels is None:
            # No voxel message newer than the pointclouds yet; retry later.
            rate.sleep()
            continue
        elif isinstance(occupied_voxels, Marker):
            # skimap publishes a Marker whose points are the voxel centres.
            config['occupied_voxel_points'] = np.array([[point.x, point.y, point.z] for point in occupied_voxels.points])
        else:
            # octomap/voxblox publish a PointCloud2 of voxel centres.
            config['occupied_voxel_points'] = np.array(list(pointc2.read_points(occupied_voxels, skip_nans=True, field_names=("x", "y", "z"))))

        # Gather current link states
        filter_link_state(config['link_state'], link_names)
        for original_pointcloud in config['pointcloud_msgs_to_process'].values():
            start_time = time.time()

            # Convert the pointcloud message to a numpy array
            pointcloud = np.array(list(pointc2.read_points(original_pointcloud, skip_nans=True, field_names=("x", "y", "z"))))

            # Downsample the dense camera pointcloud to a lower resolution
            if original_pointcloud.header.frame_id == 'camera_1_depth_optical_frame':
                pointcloud = block_reduce(pointcloud, (15, 1), np.mean)

            # Transform the pointcloud to the world frame
            pointcloud_world = np.transpose(transform(world_frame, original_pointcloud.header.frame_id, listener, pointcloud))

            # For every joint with a detected volume, collect the points inside it
            for link_name, link_data in link_names.items():
                joints_data = link_data.get('joints')
                if joints_data:
                    for joint_name, joint_data in joints_data.items():
                        if joint_data.get('d_volume') is not None:
                            inside_points = link_names[link_name]['joints'][joint_name]['d_volume'].is_inside(pointcloud_world,
                                                                                            link_names, visualize)
                            full_inside_points.append(inside_points)
                            if visualize:
                                marker_array.markers.append(link_names[link_name]['joints'][joint_name]['marker'])

            # Visualize
            if visualize:
                marker_array_pub.publish(marker_array)
            # Print elapsed time
            end_time = time.time()
            elapsed_time = (end_time - start_time) * 1000
            print(f"The for loop for the frame_id {original_pointcloud.header.frame_id} took {elapsed_time:.2f} ms.")

        # Score the deduplicated ground-truth points against the voxel map.
        start_time = time.time()
        full_inside_points_array = np.unique(np.concatenate(full_inside_points, axis=0), axis=0)
        precision, recall, f1_score = verify_points_in_voxels(full_inside_points_array, config['occupied_voxel_points'], resolution, 0.01,
                                                            visualize, metrics_marker_array_pub, gt_marker_array_pub)
        # Print elapsed time
        end_time = time.time()
        elapsed_time = (end_time - start_time) * 1000
        print(f"The metric calculation took {elapsed_time:.2f} ms.")

        # Get the current time
        current_time = rospy.Time.now()

        # Append values to DataFrame
        df = pd.concat((df, pd.DataFrame({"Time": [current_time], "Precision": [precision], "Recall": [recall], "F1 Score": [f1_score]})), ignore_index=True)

        if visualize:
            full_inside_points_list = full_inside_points_array.tolist()
            marker_inside_points = create_marker_pc2(full_inside_points_list, world_frame, [0.5, 0.5, 0.5])
            marker_inside_pub.publish(marker_inside_points)

        # Reset per-iteration accumulators
        full_inside_points = []
        config['pointcloud_msgs_to_process'] = dict()

        rate.sleep()

    # Calculate average of each column
    # NOTE(review): the "Time" column holds rospy.Time objects; df.mean()
    # relies on pandas handling the non-numeric column — confirm with the
    # installed pandas version.
    averages = df.mean()

    # Write DataFrame to a CSV file
    df.to_csv(rospack.get_path('larcc_volumetric') + f'/results/{args["experiment_name"]}/full_results.csv', index=False)
    averages.to_csv(rospack.get_path('larcc_volumetric') + f'/results/{args["experiment_name"]}/averages.csv', index=False)

# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
