#!/usr/bin/env python3
import argparse
import json
import math
import time
import rospkg
from colorama import Fore
from tensorflow import keras
import tensorflow as tf
from std_msgs.msg import String, Float64MultiArray, Float64, Bool
from DataForLearning import DataForLearning
import numpy as np
import rospy
from physical_interaction.msg import PhysicalClassification

from tabulate import tabulate
import pyfiglet


def print_tabulate(label, real_time_predictions):
    """Pretty-print a classification result to the console.

    Renders ``label`` as a large ASCII-art banner, then prints one table
    per prediction row with the per-class scores.

    :param label: predicted class name shown as the figlet banner
    :param real_time_predictions: iterable of score vectors; indices 0-3
        map to PULL, PUSH, SHAKE and TWIST respectively
    """
    banner = pyfiglet.figlet_format(label, font="space_op", width=500)
    print(Fore.LIGHTBLUE_EX + banner + Fore.RESET)

    headers = [" ", "PULL", "PUSH", "SHAKE", "TWIST"]
    for scores in list(real_time_predictions):
        table_rows = [['Output', scores[0], scores[1], scores[2], scores[3]]]
        print(tabulate(list(table_rows), headers=headers, tablefmt="fancy_grid"))
        print("\n")


def normalize_data(vector, measurements, train_config):
    """Normalize a flat sample vector cluster-by-cluster to [-1, 1].

    The flat ``vector`` is reshaped to ``(measurements, features)`` and each
    group of adjacent columns listed in
    ``train_config["normalization_clusters"]`` is divided by the largest
    absolute value found inside that group.

    :param vector: 1-D array-like of length ``measurements * features``
    :param measurements: number of time samples (rows after reshape)
    :param train_config: dict with key ``"normalization_clusters"`` — a list
        of column counts, one entry per cluster, summing to ``features``
    :return: normalized data as a ``(1, len(vector))`` numpy array
    """
    # Accept plain sequences too (the original required a numpy array
    # because it reads vector.shape at the end).
    vector = np.asarray(vector, dtype=float)

    data_array = np.reshape(vector, (measurements, len(vector) // measurements))
    data_array_norm = np.empty((data_array.shape[0], 0))

    idx = 0
    for n in train_config["normalization_clusters"]:
        data_sub_array = data_array[:, idx:idx + n]
        idx += n

        # Largest magnitude inside the cluster. Guard against an all-zero
        # cluster, which previously caused a division by zero (NaN/inf).
        data_max = max(abs(data_sub_array.min()), abs(data_sub_array.max()))
        if data_max == 0:
            data_max = 1.0

        data_array_norm = np.hstack((data_array_norm, data_sub_array / data_max))

    # Flatten back to a single row, as the model expects a (1, features) batch.
    return np.reshape(data_array_norm, (1, vector.shape[0]))


def add_to_vector(data, vector, func_first_timestamp, dic_offset, pub_data):
    """Append one offset-corrected measurement sample to ``vector`` and publish it.

    Builds a 13-element sample [timestamp, 6 joint efforts, 3 forces,
    3 torques], subtracting the per-channel calibration offsets, publishes it
    on ``pub_data`` (for the force/torque GUI) and appends it to ``vector``.

    :param data: DataForLearning-like object exposing ``timestamp()``,
        ``joints_effort`` and ``wrench_force_torque``
    :param vector: 1-D numpy array the new sample is appended to
    :param func_first_timestamp: timestamp of the first sample of the current
        acquisition window, or None on the first call
    :param dic_offset: calibration offsets keyed "j0".."j5", "fx".."fz", "mx".."mz"
    :param pub_data: ROS publisher for a Float64MultiArray message
    :return: tuple (extended vector, first timestamp of the window)
    """
    msg = Float64MultiArray()

    # First sample of a window anchors the time axis at zero.
    if func_first_timestamp is None:
        func_first_timestamp = data.timestamp()
        timestamp = 0.0
    else:
        timestamp = data.timestamp() - func_first_timestamp

    new_data = np.array([timestamp, data.joints_effort[0] - dic_offset["j0"],
                         data.joints_effort[1] - dic_offset["j1"],
                         data.joints_effort[2] - dic_offset["j2"],
                         data.joints_effort[3] - dic_offset["j3"],
                         data.joints_effort[4] - dic_offset["j4"],
                         data.joints_effort[5] - dic_offset["j5"],
                         data.wrench_force_torque.force.x - dic_offset["fx"],
                         data.wrench_force_torque.force.y - dic_offset["fy"],
                         data.wrench_force_torque.force.z - dic_offset["fz"],
                         data.wrench_force_torque.torque.x - dic_offset["mx"],
                         data.wrench_force_torque.torque.y - dic_offset["my"],
                         # BUG FIX: was "- - dic_offset['mz']" (double negative),
                         # which ADDED the mz offset instead of subtracting it
                         # like every other channel.
                         data.wrench_force_torque.torque.z - dic_offset["mz"]])

    msg.data = new_data
    pub_data.publish(msg)

    return np.append(vector, new_data), func_first_timestamp


def calc_data_mean(data):
    """Return the mean of the gripper contact channels.

    Averages force.z scaled by 1/10 together with the three torque
    components; used to detect when an external action starts/ends.

    :param data: object exposing ``wrench_force_torque.force`` and
        ``wrench_force_torque.torque``
    :return: scalar mean of the four channel values
    """
    ft = data.wrench_force_torque
    samples = np.array([ft.force.z / 10, ft.torque.x, ft.torque.y, ft.torque.z])
    return np.mean(samples)


def get_statistics(data_list):
    """Return the mean and population standard deviation of ``data_list``.

    Note: despite the original local name "var", the second value returned is
    the standard deviation (sqrt of the mean squared deviation, ddof=0), not
    the variance — callers compare it against a small threshold to decide
    whether the calibration readings were stable.

    :param data_list: non-empty sequence of numbers
    :return: tuple (mean, population standard deviation) as floats
    """
    samples = np.asarray(data_list, dtype=float)
    # np.std with its default ddof=0 computes exactly
    # sqrt(mean((x - mean)^2)), matching the original hand-written loop.
    return float(np.mean(samples)), float(np.std(samples))


def offset_calculation(dic):
    """Collapse each channel's list of calibration samples to its mean.

    :param dic: mapping of channel name -> list of recorded samples
    :return: mapping of channel name -> mean of those samples
    """
    return {channel: np.mean(samples) for channel, samples in dic.items()}


if __name__ == '__main__':

    rospack = rospkg.RosPack()

    # Resolve the ROS package root so config/model paths work regardless of cwd.
    package_path = rospack.get_path('physical_interaction')
    # ---------------------------------------------------------------------------------------------
    # --------------------------------------INPUT VARIABLES----------------------------------------
    # ---------------------------------------------------------------------------------------------

    # f = open('../config/config.json')
    f = open(f'{package_path}/config/config.json')
    config = json.load(f)
    f.close()

    # Rest pose the robot must hold before a classification round starts.
    # NOTE(review): indices 2/1/0 are swapped relative to the config order —
    # presumably to match the joint ordering reported by the driver; confirm
    # against the robot's joint_states layout.
    desired_pose = [config["initial_pose"][2], config["initial_pose"][1], config["initial_pose"][0],
                    config["initial_pose"][3], config["initial_pose"][4], config["initial_pose"][5]]

    f = open(f'{package_path}/config/clusters_max_min.json')
    clusters_max_min = json.load(f)
    f.close()

    # Per-cluster absolute maxima for the legacy normalization scheme.
    # NOTE(review): data2norm is built but never used below — the live
    # normalize_data call receives `config` instead. Candidate for removal.
    data_max_timestamp = abs(max(clusters_max_min["timestamp"]["max"], clusters_max_min["timestamp"]["min"], key=abs))
    data_max_joints = abs(max(clusters_max_min["joints"]["max"], clusters_max_min["joints"]["min"], key=abs))
    data_max_gripper_F = abs(max(clusters_max_min["gripper_F"]["max"], clusters_max_min["gripper_F"]["min"], key=abs))
    data_max_gripper_M = abs(max(clusters_max_min["gripper_M"]["max"], clusters_max_min["gripper_M"]["min"], key=abs))
    data2norm = [data_max_timestamp, data_max_joints, data_max_gripper_F, data_max_gripper_M]
    # model = keras.models.load_model("../nn_models/feedforward_model")
    # model = keras.models.load_model("../nn_models/cnn3_50ms")
    # model = keras.models.load_model("../nn_models/cnn4_model_20ms")
    model = keras.models.load_model(f"{package_path}/nn_models/myModel2")

    # ---------------------------------------------------------------------------------------------
    # -------------------------------INITIATE COMMUNICATION----------------------------------------
    # ---------------------------------------------------------------------------------------------

    rospy.init_node("physical_classification_old", anonymous=False)

    # For force/torque GUI
    pub_vector = rospy.Publisher("learning_data", Float64MultiArray, queue_size=10)
    pub_class = rospy.Publisher("classification", PhysicalClassification, queue_size=10)

    # Live proxy of the robot state (joint efforts, wrench, timestamps).
    data_for_learning = DataForLearning()

    rate = rospy.Rate(config["rate"])

    time.sleep(0.2) # Waiting time for the ROS nodes to initiate properly

    # ---------------------------------------------------------------------------------------------
    # -------------------------------INITIATE ROBOT------------------------------------------------
    # ---------------------------------------------------------------------------------------------
    #

    list_calibration = []
    # Per-channel raw readings gathered during calibration; averaged later into
    # the offsets subtracted from every sample.
    dic_offset_calibration = {"fx": [], "fy": [], "fz": [], "mx": [],
                              "my": [], "mz": [], "j0": [], "j1": [],
                              "j2": [], "j3": [], "j4": [], "j5": []}
    dic_variable_offset = None

    # Number of samples per acquisition window (seconds * Hz).
    limit = int(config["time"] * config["rate"])

    trainning_data_array = np.empty((0, limit * len(config["data"])))

    # True while actions follow each other without returning to rest, so the
    # calibration/wait phase is skipped between them.
    sequential_actions = False
    first_time_stamp_show = None
    vector_data_show = np.empty((0, 0))
    rest_state_mean = 0
    # NOTE(review): 651 is hard-coded — presumably limit * channels + 1 label
    # column; it silently breaks if config["time"]/["rate"] change. Confirm.
    predicted_data_saved = np.empty((0, 651))
    predictions_saved = np.empty((0, 4))

    while not rospy.is_shutdown(): # This is the data acquisition cycle
        # Only operate while the arm is (numerically) at the rest pose.
        if math.dist(data_for_learning.joints_position, desired_pose) < 0.001:
            if not sequential_actions:
                st = time.time()
                while not rospy.is_shutdown(): # This is the calibration cycle
                    print("Calculating rest state variables...")
                    # Reset the accumulators for a fresh calibration attempt.
                    list_calibration = []
                    dic_offset_calibration = {"fx": [], "fy": [], "fz": [], "mx": [],
                                              "my": [], "mz": [], "j0": [], "j1": [],
                                              "j2": [], "j3": [], "j4": [], "j5": []}

                    class_msg = PhysicalClassification()
                    class_msg.header = data_for_learning.header
                    class_msg.classification = "Calibrating"
                    pub_class.publish(class_msg)

                    # Collect 49 snapshots of the rest state (~0.25 s at the
                    # 5 ms sleep below).
                    for i in range(0, 49):
                        list_calibration.append(calc_data_mean(data_for_learning))
                        if dic_variable_offset is not None:
                            # Keep feeding the GUI plot during recalibration.
                            add_to_vector(data_for_learning, vector_data_show, None, dic_variable_offset, pub_vector)

                        dic_offset_calibration["fx"].append(data_for_learning.wrench_force_torque.force.x)
                        dic_offset_calibration["fy"].append(data_for_learning.wrench_force_torque.force.y)
                        dic_offset_calibration["fz"].append(data_for_learning.wrench_force_torque.force.z)
                        dic_offset_calibration["mx"].append(data_for_learning.wrench_force_torque.torque.x)
                        dic_offset_calibration["my"].append(data_for_learning.wrench_force_torque.torque.y)
                        dic_offset_calibration["mz"].append(data_for_learning.wrench_force_torque.torque.z)

                        dic_offset_calibration["j0"].append(data_for_learning.joints_effort[0])
                        dic_offset_calibration["j1"].append(data_for_learning.joints_effort[1])
                        dic_offset_calibration["j2"].append(data_for_learning.joints_effort[2])
                        dic_offset_calibration["j3"].append(data_for_learning.joints_effort[3])
                        dic_offset_calibration["j4"].append(data_for_learning.joints_effort[4])
                        dic_offset_calibration["j5"].append(data_for_learning.joints_effort[5])

                        time.sleep(0.005)

                    class_msg = PhysicalClassification()
                    class_msg.header = data_for_learning.header
                    class_msg.classification = "None"
                    pub_class.publish(class_msg)

                    rest_state_mean, rest_state_var = get_statistics(list_calibration)
                    dic_variable_offset = offset_calculation(dic_offset_calibration)
                    print(rest_state_mean)
                    print(rest_state_var)

                    # Accept the calibration only when the readings were stable
                    # (hard-coded stability threshold); otherwise retry.
                    if rest_state_var < 0.03:
                        break
                print("Calibration time: " + str(time.time() - st))
                print(f"Waiting for action to initiate prediction ...")

                while not rospy.is_shutdown(): # This cycle waits for the external force to start storing data
                    data_mean = calc_data_mean(data_for_learning)
                    # Deviation of the current contact signal from the rest mean.
                    variance = data_mean - rest_state_mean
                    add_to_vector(data_for_learning, vector_data_show, None, dic_variable_offset, pub_vector)
                    if abs(variance) > config["force_threshold_start"]:
                        break

                    time.sleep(0.1)
                time.sleep(config["waiting_offset"]) # time waiting to initiate the experiment

            # ---------------------------------------------------------------------------------------------
            # -------------------------------------GET DATA------------------------------------------------
            # ---------------------------------------------------------------------------------------------

            end_experiment = False
            first_time_stamp = None
            vector_data = np.empty((0, 0))

            i = 0
            treshold_counter = 0

            rate.sleep()  # The first time rate.sleep was used it was giving problems (would not wait the right amount of time)

            try:
                while not rospy.is_shutdown() and i < limit: # This cycle stores data for a fixed amount of time
                    i += 1
                    vector_data, first_time_stamp = add_to_vector(data_for_learning,
                                                                  vector_data, first_time_stamp, dic_variable_offset, pub_vector)
                    data_mean = calc_data_mean(data_for_learning)
                    variance = data_mean - rest_state_mean
                    # // old values th: 0.2 and 0.075
                    # Abort early when the contact signal stays below the end
                    # threshold for several consecutive samples (force released).
                    if abs(variance) < config["force_threshold_end"]:
                        treshold_counter += 1
                        if treshold_counter >= config["threshold_counter_limit"]:
                            end_experiment = True
                            break
                    else:
                        treshold_counter = 0

                    rate.sleep()
            # NOTE(review): bare except swallows every error (not just Ctrl+C)
            # and prints a misleading message — narrow to KeyboardInterrupt /
            # rospy.ROSInterruptException.
            except:
                print("ctrl+C pressed1")

            try:
                if end_experiment:
                    # Window was cut short: not enough samples for the model.
                    sequential_actions = False
                    print("\nNot enough for prediction\n")
                    class_msg = PhysicalClassification()
                    class_msg.header = data_for_learning.header
                    class_msg.classification = "None"
                    pub_class.publish(class_msg)
                else:
                    sequential_actions = True
                    # data_norm = normalize_data(vector_data, limit, data2norm)
                    data_norm = normalize_data(vector_data, limit, config)
                    predictions = model.predict(x=data_norm, verbose=2)
                    labels = config["action_classes"]
                    max_idx = np.argmax(list(predictions))
                    print(max_idx)
                    print(predictions[0][int(max_idx)])
                    print(predictions)
                    predicted_label = labels[int(max_idx)]

                    # Archive the raw window (+ winning class index) and the raw
                    # scores. NOTE(review): these arrays grow unboundedly and are
                    # never persisted in this file — verify intent.
                    vector_data = np.append(vector_data, max_idx)
                    predicted_data_saved = np.append(predicted_data_saved, [vector_data], axis=0)
                    predictions_saved = np.append(predictions_saved, predictions, axis=0)
                    # pub_class.publish(predicted_label + " " + str(round(float(predictions[0][int(max_idx)] * 100), 2)) + "%")

                    class_msg = PhysicalClassification()
                    class_msg.header = data_for_learning.header
                    class_msg.classification = predicted_label
                    pub_class.publish(class_msg)

                    print("-----------------------------------------------------------")
                    print_tabulate(predicted_label, predictions)
                    print("-----------------------------------------------------------")
            # NOTE(review): bare except here also hides prediction/shape errors
            # (e.g. the 651-column mismatch) behind a Ctrl+C message.
            except:
                print("ctrl+C pressed2")
        else:
            # Robot not at the rest pose: publish an explicit "None" class.
            class_msg = PhysicalClassification()
            class_msg.header = data_for_learning.header
            class_msg.classification = "None"
            pub_class.publish(class_msg)

    del data_for_learning
