#!/usr/bin/env python3

import numpy as np
import yaml
import argparse
import sys
import rospy

# 3rd-party
from scipy.spatial.transform import Rotation as R
import atom_msgs.srv
from tf.listener import TransformListener

# local packages
from atom_core.config_io import loadConfig
from atom_core.ros_utils import filterLaunchArguments
from atom_core.transformations import compareTransforms


def convertToMatrix(t):
    """
    Build a 4x4 homogeneous transformation matrix from a translation/quaternion pair.

    Parameters:
    - t (sequence): t[0] is the 3-element translation vector and t[1] is the 4-element quaternion,
      given in the order scipy's Rotation.from_quat expects (x, y, z, w).

    Returns:
    - numpy.ndarray: 4x4 float matrix with the rotation in the upper-left 3x3 block, the translation
      in the last column and [0, 0, 0, 1] as the bottom row.
    """
    translation, quaternion = t[0], t[1]

    # Start from the identity so the homogeneous bottom row is already in place.
    homogeneous = np.eye(4)
    homogeneous[:3, :3] = R.from_quat(quaternion).as_matrix()
    homogeneous[:3, 3] = translation
    return homogeneous


def compareTransformsWithThresholds(t1, t2, translation_threshold, rotation_threshold):
    """
    Tell whether two transforms are close enough to be considered the same pose.

    Parameters:
    - t1 (list): First transform, as [translation (3 elements), quaternion (4 elements)].
    - t2 (list): Second transform, same layout as t1.
    - translation_threshold (float): Exclusive upper bound for the translation error.
    - rotation_threshold (float): Exclusive upper bound for the rotation error.

    Returns:
    - bool: True only when both the translation error and the rotation error reported by
      compareTransforms are strictly below their thresholds, False otherwise.
    """
    # compareTransforms works on 4x4 homogeneous matrices, so convert both inputs first.
    first_matrix = convertToMatrix(t1)
    second_matrix = convertToMatrix(t2)

    # Only the first two of the eight returned values (translation and rotation error) are needed.
    errors = compareTransforms(first_matrix, second_matrix)
    translation_error, rotation_error, _, _, _, _, _, _ = errors

    within_translation = translation_error < translation_threshold
    if not within_translation:
        return within_translation
    return rotation_error < rotation_threshold


class AutomaticDataCollector:
    """
    Request a data collection automatically whenever every tracked TF frame is static.

    While enabled (see the toggle service), the collector samples the transform between each frame
    known to the TF buffer and the world link, keeps a short time window of samples per frame, and
    calls the '/collect_data/save_collection' service once all tracked frames are flagged static.
    A new collection is only requested after the robot has moved and become static again.
    """

    def __init__(self, config, rate=10, time_threshold=4.0, translation_threshold=0.1, rotation_threshold=0.1):
        """
        Initialize the AutomaticDataCollector.

        Parameters:
        - config (dict): The ATOM config; only its 'world_link' entry is read here.
        - rate (float): ROS Rate (Hz) for the data collection loop.
        - time_threshold (float): Length, in seconds, of the per-frame sample window. A frame's static
          flag is re-evaluated whenever its window grows longer than this.
        - translation_threshold (float): Exclusive upper bound for the translation error between the
          oldest and newest sample of a window for the frame to count as static.
        - rotation_threshold (float): Exclusive upper bound for the rotation error between the oldest
          and newest sample of a window. NOTE(review): the unit/metric is whatever
          atom_core's compareTransforms returns as its second value - confirm there.
        """
        self.world_link = config['world_link']
        self.rate = rospy.Rate(rate)  # hz
        self.time_threshold = time_threshold
        self.listener = TransformListener()
        self.automatic_data_collector_on = False
        # Per-frame history: {frame: {'times': [sec, ...], 'values': [[trans, quat], ...], 'static': bool}}
        self.frame_times = {}
        self.translation_threshold = translation_threshold
        self.rotation_threshold = rotation_threshold
        self.robot_static = False
        self.collection_taken = False

        # Add service to toggle the automatic data collector
        self.service_get_dataset = rospy.Service('~toggle_automatic_data_collector',
                                                 atom_msgs.srv.ToggleAutomaticDataCollector,
                                                 self.callbackToggleAutomaticDataCollector)

    def callbackToggleAutomaticDataCollector(self, request):
        """
        Callback function for toggling the automatic data collector service.

        Each call flips the collector between enabled and disabled.

        Parameters:
        - request: Service request (its content is not used).

        Returns:
        - atom_msgs.srv.ToggleAutomaticDataCollectorResponse: Service response with success set to
          True and a message reporting the new on/off state.
        """
        print('callbackAutomaticDataCollector service called')
        self.automatic_data_collector_on = not self.automatic_data_collector_on

        response = atom_msgs.srv.ToggleAutomaticDataCollectorResponse()
        response.success = True
        response.message = f"Automatic Data Collector on: {self.automatic_data_collector_on}"
        return response

    def _recordSample(self, frame, stamp, transform):
        """
        Store one transform sample for a frame and refresh the frame's static flag.

        Parameters:
        - frame (str): Name of the frame the sample belongs to.
        - stamp (float): Timestamp of the sample, in seconds.
        - transform (list): [translation, quaternion] of the frame with respect to the world link.
        """
        history = self.frame_times.get(frame)
        if history is None:
            # First sample of this frame: nothing to compare against yet.
            self.frame_times[frame] = {'times': [stamp],
                                       'values': [transform],
                                       'static': False}
            return

        times = history['times']
        values = history['values']
        if stamp not in times:
            times.append(stamp)
            values.append(transform)

            # Drop samples from the front until the window fits in time_threshold again. Each dropped
            # sample is first compared with the newest one, so the flag reflects the motion over
            # (slightly more than) the last time_threshold seconds.
            while times[-1] - times[0] > self.time_threshold:
                history['static'] = compareTransformsWithThresholds(values[0], values[-1],
                                                                    self.translation_threshold,
                                                                    self.rotation_threshold)
                times.pop(0)
                values.pop(0)
        elif stamp == 0.0:
            # A repeated zero stamp (rospy.Time(0)) is treated as a transform that never changes;
            # presumably this is how TF reports static transforms - verify against the TF setup.
            history['static'] = True

    def _allFramesStatic(self):
        """
        Tell whether every tracked frame is currently flagged static.

        Returns:
        - bool: False while no frame is tracked yet (all() of an empty sequence would be True and
          trigger a collection without any evidence), otherwise True only if all frames are static.
        """
        if not self.frame_times:
            return False
        return all(history['static'] for history in self.frame_times.values())

    def _trySaveCollection(self):
        """
        Ask the data collector node to save a collection.

        Returns:
        - bool: The success field of the service response, or False if the service call failed.
        """
        rospy.wait_for_service('/collect_data/save_collection')
        save_collection = rospy.ServiceProxy('/collect_data/save_collection', atom_msgs.srv.SaveCollection)
        try:
            response = save_collection()
        except rospy.ServiceException as e:
            # Keep the loop alive; collection_taken stays False so the call is retried next cycle.
            rospy.logerr_throttle(5.0, f"Call to /collect_data/save_collection failed: {e}")
            return False
        return response.success

    def run(self):
        """
        Main loop for automatic data collection and checking if the robot is static.

        Runs until ROS shuts down. While the collector is toggled off the loop only sleeps. While it
        is on, every cycle samples all frames, updates robot_static and, on a moving-to-static
        transition, requests a collection.
        """
        # Local import: only the exception type is needed here, and the file itself imports
        # nothing but tf.listener.TransformListener.
        import tf

        frames_dict = {}
        while not rospy.is_shutdown():
            if not self.automatic_data_collector_on:
                # Sleep here too, otherwise the loop busy-spins while the collector is off.
                self.rate.sleep()
                continue

            # Load full frames dictionary only once (retried while the TF buffer reports no frames).
            # `or {}` covers an empty YAML document (None) or an empty list.
            if not frames_dict:
                frames_dict = yaml.safe_load(self.listener._buffer.all_frames_as_yaml()) or {}

            for frame in frames_dict:
                try:
                    t = self.listener.getLatestCommonTime(frame, self.world_link)
                    (trans, quat) = self.listener.lookupTransform(frame, self.world_link, t)
                except tf.Exception as e:
                    # A frame that cannot be resolved against the world link must not kill the node;
                    # skip it for this cycle.
                    rospy.logwarn_throttle(5.0, f"Could not get transform from {frame} to {self.world_link}: {e}")
                    continue
                self._recordSample(frame, t.to_sec(), [trans, quat])

            # Check if all frames are static
            self.robot_static = self._allFramesStatic()

            # If the robot is static and collection hasn't been taken yet, attempt to save the collection
            if self.robot_static and not self.collection_taken:
                self.collection_taken = self._trySaveCollection()
            elif not self.robot_static:
                # The robot moved: arm the collector for the next static period.
                self.collection_taken = False

            self.rate.sleep()


    

def main():
    """
    Entry point: read the calibration file given on the command line and run the collector node.

    The ROS launch-specific arguments are stripped from sys.argv before parsing, the ATOM config is
    loaded from the path given with -c/--calibration_file, and the 'automatic_data_collector' node
    then runs until ROS shuts down.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--calibration_file", type=str, required=True,
                        help='full path to calibration file.')
    launch_free_argv = filterLaunchArguments(sys.argv)
    cli_args = vars(parser.parse_args(args=launch_free_argv))
    calibration_config = loadConfig(cli_args['calibration_file'])

    # The node must exist before the collector creates its rate, TF listener and service.
    rospy.init_node("automatic_data_collector")

    AutomaticDataCollector(calibration_config).run()

# Run the node only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()