#!/usr/bin/env python3

import json
import math
import time

import rospkg
from colorama import Fore
import keras
from DataForLearning import DataForLearning
import numpy as np
from std_msgs.msg import String
from geometry_msgs.msg import Wrench, Pose, WrenchStamped
from physical_interaction.msg import PhysicalClassification
from std_msgs.msg import String, Float64MultiArray, Float64, Bool

import rospy
from tabulate import tabulate
import pyfiglet

wrench_force_torque = Wrench()
wrench_pose = Pose()
joints_position = []
joints_effort = []


def print_tabulate(label, real_time_predictions):
    """Print *label* as a light-blue FIGlet banner, then one grid table per prediction row.

    Args:
        label: text rendered as the banner (font "space_op").
        real_time_predictions: iterable of score rows; the first four entries
            of each row are shown under the PULL, PUSH, SHAKE and TWIST columns.
    """
    banner = pyfiglet.figlet_format(label, font="space_op", width=500)
    print(Fore.LIGHTBLUE_EX + banner + Fore.RESET)

    column_titles = [" ", "PULL", "PUSH", "SHAKE", "TWIST"]
    for scores in real_time_predictions:
        # Index explicitly so a row shorter than 4 still raises IndexError.
        table = [["Output"] + [scores[idx] for idx in range(4)]]
        print(tabulate(table, headers=column_titles, tablefmt="fancy_grid"))
        print("\n")


def add_to_vector(data, vector, func_first_timestamp, dic_offset):
    """Append one offset-compensated sample to a flat NumPy vector.

    The appended sample has 13 values, in this order:
    ``[t, j0, j1, j2, j3, j4, j5, fx, fy, fz, mx, my, mz]``. Here ``t`` is
    ``data.timestamp()`` relative to ``func_first_timestamp``. Every other
    value is the raw reading minus the matching ``dic_offset`` entry.

    Args:
        data: object exposing ``timestamp()``, ``joints_effort`` (indexable,
            at least 6 entries) and ``wrench_force_torque`` (with ``force`` and
            ``torque`` objects carrying ``x``/``y``/``z``).
        vector: array to append to. ``np.append`` is called without an axis,
            so the returned vector is always flattened to 1-D.
        func_first_timestamp: timestamp of the first sample of the window, or
            ``None`` to start a new window at this sample (``t = 0.0``).
        dic_offset: rest-state offsets keyed ``"j0"``..``"j5"``, ``"fx"``,
            ``"fy"``, ``"fz"``, ``"mx"``, ``"my"``, ``"mz"``.

    Returns:
        ``(new_vector, first_timestamp)``. Pass ``first_timestamp`` back in on
        the next call so timestamps stay relative to the window start.
    """
    if func_first_timestamp is None:
        func_first_timestamp = data.timestamp()
        timestamp = 0.0
    else:
        timestamp = data.timestamp() - func_first_timestamp

    force = data.wrench_force_torque.force
    torque = data.wrench_force_torque.torque
    new_data = np.array([
        timestamp,
        *(data.joints_effort[j] - dic_offset[f"j{j}"] for j in range(6)),
        force.x - dic_offset["fx"],
        force.y - dic_offset["fy"],
        force.z - dic_offset["fz"],
        torque.x - dic_offset["mx"],
        torque.y - dic_offset["my"],
        # Previously "- - dic_offset", which ADDED the Mz offset instead of
        # removing it. NOTE(review): if a model was trained on data produced
        # with the old sign, re-check its accuracy.
        torque.z - dic_offset["mz"],
    ])

    return np.append(vector, new_data), func_first_timestamp


def normalize_data(vector, measurements, data2norm):
    """Reshape a flat sample vector into one scaled model input window.

    The flat vector is viewed as ``measurements`` rows. Columns 0..12 are
    then divided by the per-group factors in ``data2norm``:
    column 0 by ``data2norm[0]``, columns 1-6 by ``data2norm[1]``,
    columns 7-9 by ``data2norm[2]`` and columns 10-12 by ``data2norm[3]``.
    This matches the row layout produced by ``add_to_vector``.

    Returns:
        Array of shape ``(1, measurements, 13)`` (when rows have >= 13 columns).
    """
    row_count = measurements
    col_count = int(len(vector) / measurements)
    table = np.reshape(vector, (row_count, col_count))

    # (start, stop) column span of each group; group i is scaled by data2norm[i].
    column_groups = ((0, 1), (1, 7), (7, 10), (10, 13))
    # Seed with an empty float64 block so the result dtype matches the original.
    pieces = [np.empty((table.shape[0], 0))]
    for scale_idx, (start, stop) in enumerate(column_groups):
        pieces.append(table[:, start:stop] / data2norm[scale_idx])
    scaled = np.hstack(pieces)

    return np.reshape(scaled, (1,) + scaled.shape)


def calc_data_mean(data):
    """Return the mean of ``[Fz / 10, Mx, My, Mz]`` read from ``data.wrench_force_torque``.

    Fz is divided by 10 before averaging. NOTE(review): presumably this brings
    it to a magnitude comparable with the torques; confirm.
    """
    wrench = data.wrench_force_torque
    components = (wrench.force.z / 10, wrench.torque.x,
                  wrench.torque.y, wrench.torque.z)
    return np.mean(np.array(components))


def get_statistics(data_list):
    """Return ``(mean, population standard deviation)`` of *data_list*.

    Note that the second value is the square root of the mean squared
    deviation (i.e. a standard deviation, not a variance), despite
    callers naming it ``*_var``.

    Raises:
        ZeroDivisionError: if *data_list* is empty.
    """
    center = np.mean(np.array(data_list))
    # sum() starts at 0 and adds in order, exactly like the manual loop.
    squared_dev_total = sum((value - center) ** 2 for value in data_list)
    return center, math.sqrt(squared_dev_total / len(data_list))


def offset_calculation(dic):
    """Map every key of *dic* to the mean of its list of samples (key order preserved)."""
    return {name: np.mean(samples) for name, samples in dic.items()}


def _publish_classification(publisher, header, label):
    """Publish a ``PhysicalClassification`` message carrying *header* and *label*."""
    class_msg = PhysicalClassification()
    class_msg.header = header
    class_msg.classification = label
    publisher.publish(class_msg)


if __name__ == '__main__':

    rospack = rospkg.RosPack()

    package_path = rospack.get_path('physical_interaction')
    labels = ["pull", "push", "shake", "twist"]  # model output index -> label

    rospy.init_node("physical_classification", anonymous=True)
    # NOTE(review): pub_vector is registered but nothing is published on it here.
    pub_vector = rospy.Publisher("learning_data", Float64MultiArray, queue_size=10)
    pub_class = rospy.Publisher("classification", PhysicalClassification, queue_size=10)

    data_for_learning = DataForLearning()

    with open(f'{package_path}/config/config.json', encoding="utf-8") as f:
        config = json.load(f)

    # Entries 0 and 2 of "initial_pose" are swapped. NOTE(review): presumably to
    # match the joint order of data_for_learning.joints_position; confirm.
    initial_pose = config["initial_pose"]
    desired_pose = [initial_pose[2], initial_pose[1], initial_pose[0],
                    initial_pose[3], initial_pose[4], initial_pose[5]]

    with open(f'{package_path}/config/clusters_max_min.json', encoding="utf-8") as f:
        clusters_max_min = json.load(f)

    # Per-group scale factor: the larger magnitude of the recorded max/min.
    # NOTE(review): data2norm is never used below (normalize_data is not called),
    # so the model receives un-normalized windows. Confirm this matches how
    # demo_model was trained.
    data2norm = [abs(max(clusters_max_min[group]["max"], clusters_max_min[group]["min"], key=abs))
                 for group in ("timestamp", "joints", "gripper_F", "gripper_M")]

    model = keras.models.load_model(f"{package_path}/nn_models/demo_model")

    rate = rospy.Rate(100)
    limit = 20  # samples per window; the window is fed to the model as (1, limit, features)

    while not rospy.is_shutdown():  # This is the data acquisition cycle
        end_experiment = False
        dic_variable_offset = None
        rest_state_mean = 0

        if math.dist(data_for_learning.joints_position, desired_pose) < 0.001:

            # --- Calibration: sample the rest state until it is steady enough ---
            st = time.time()
            while not rospy.is_shutdown():
                print("Calculating rest state variables...")
                list_calibration = []
                dic_offset_calibration = {key: [] for key in ("fx", "fy", "fz", "mx", "my", "mz",
                                                              "j0", "j1", "j2", "j3", "j4", "j5")}
                _publish_classification(pub_class, data_for_learning.header, "Calibrating")

                for _ in range(49):
                    list_calibration.append(calc_data_mean(data_for_learning))

                    force = data_for_learning.wrench_force_torque.force
                    torque = data_for_learning.wrench_force_torque.torque
                    dic_offset_calibration["fx"].append(force.x)
                    dic_offset_calibration["fy"].append(force.y)
                    dic_offset_calibration["fz"].append(force.z)
                    dic_offset_calibration["mx"].append(torque.x)
                    dic_offset_calibration["my"].append(torque.y)
                    dic_offset_calibration["mz"].append(torque.z)
                    for j in range(6):
                        dic_offset_calibration[f"j{j}"].append(data_for_learning.joints_effort[j])

                    time.sleep(0.005)

                _publish_classification(pub_class, data_for_learning.header, "None")
                # get_statistics returns (mean, standard deviation).
                rest_state_mean, rest_state_var = get_statistics(list_calibration)
                dic_variable_offset = offset_calculation(dic_offset_calibration)
                print(rest_state_mean)
                print(rest_state_var)

                if rest_state_var < 0.03:
                    break

            print("\nCalibration time: " + str(time.time() - st))
            print("Waiting for action to initiate prediction ...")

            # --- Wait until the signal deviates from the rest state ---
            while not rospy.is_shutdown():
                if abs(calc_data_mean(data_for_learning) - rest_state_mean) > 0.2:
                    break
                time.sleep(0.1)

            # --- Record a fixed-length window for the classifier ---
            samples = 0
            first_time_stamp = None
            vector_data = np.empty((0, 0))
            below_threshold = 0
            try:
                while not rospy.is_shutdown() and samples < limit:
                    vector_data, first_time_stamp = add_to_vector(data_for_learning, vector_data,
                                                                  first_time_stamp, dic_variable_offset)
                    samples += 1

                    deviation = calc_data_mean(data_for_learning) - rest_state_mean
                    if abs(deviation) < 0.2:
                        # Four consecutive near-rest samples: the action ended too early.
                        below_threshold += 1
                        if below_threshold >= 4:
                            end_experiment = True
                            break
                    else:
                        below_threshold = 0

                    rate.sleep()
            except (KeyboardInterrupt, rospy.ROSInterruptException):
                print("ctrl+C pressed1")

            if end_experiment:
                print("\nNot enough for prediction\n")
                _publish_classification(pub_class, data_for_learning.header, "None")
            elif samples < limit:
                # Interrupted/shutdown before the window was full: a partial
                # window cannot be reshaped to the model input, so skip it.
                print("\nAcquisition interrupted before the window was complete; skipping prediction\n")
            else:
                window = np.reshape(vector_data, (1, limit, vector_data.size // limit))

                predictions = model.predict(x=window, verbose=2)
                predicted_label = labels[int(np.argmax(predictions))]

                _publish_classification(pub_class, data_for_learning.header, predicted_label)

                print("-----------------------------------------------------------")
                print_tabulate(predicted_label, predictions)
                print("-----------------------------------------------------------")
                time.sleep(2)

        else:
            _publish_classification(pub_class, data_for_learning.header, "None")
            # Throttle to the node rate; without a sleep this branch spins the
            # CPU and floods the "classification" topic.
            try:
                rate.sleep()
            except rospy.ROSInterruptException:
                break

