import math
import os

import numpy as np
import networkx as nx

if os.environ.get('USER') == 'mike':
    from tf_transformations import quaternion_from_euler
else:
    from tf.transformations import quaternion_from_euler

from urdf_parser_py.urdf import URDF


def quaternionMatrix(quaternion):
    """Build a 4x4 homogeneous rotation matrix from a quaternion.

    Adapted from transformations.py by Christoph Gohlke (2006).

    Only the first four components of ``quaternion`` are used, with the
    scalar part last (index 3), as the entry formulas below show. The input
    need not be unit length. A quaternion whose squared norm is below
    4 * machine epsilon is treated as degenerate and yields the identity.
    The caller's sequence/array is never modified.

    @param quaternion: indexable sequence of at least four numbers
    @return: 4x4 float64 numpy array
    """
    tolerance = np.finfo(float).eps * 4.0

    v = np.array(quaternion[:4], dtype=np.float64)  # np.array copies its input
    norm_sq = np.dot(v, v)
    if norm_sq < tolerance:
        # Degenerate quaternion: no meaningful rotation, fall back to identity.
        return np.identity(4)

    # Scaling by sqrt(2 / |q|^2) folds both the normalisation and the factor 2
    # of the rotation-matrix formula into the outer product below.
    v *= math.sqrt(2.0 / norm_sq)
    p = np.outer(v, v)
    xx, xy, xz, xw = p[0, 0], p[0, 1], p[0, 2], p[0, 3]
    yy, yz, yw = p[1, 1], p[1, 2], p[1, 3]
    zz, zw = p[2, 2], p[2, 3]

    rows = [
        [1.0 - yy - zz, xy - zw, xz + yw, 0.0],
        [xy + zw, 1.0 - xx - zz, yz - xw, 0.0],
        [xz - yw, yz + xw, 1.0 - xx - yy, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
    return np.array(rows, dtype=np.float64)

def translationQuaternionToTransform(trans, quat):
    """Compose a 4x4 homogeneous transform from a translation and a quaternion.

    @param trans: indexable with at least three components; only trans[0],
        trans[1] and trans[2] are used
    @param quat: quaternion in the form accepted by quaternionMatrix
    @return: 4x4 numpy array whose rotation block comes from
        quaternionMatrix(quat) and whose last column is (trans[0], trans[1],
        trans[2], 1)
    """
    homogeneous = quaternionMatrix(quat)
    homogeneous[0:3, 3] = (trans[0], trans[1], trans[2])
    homogeneous[3, 3] = 1
    return homogeneous

def get_transform_tree_dict(file):
    """Parse a URDF file into a pool of joint transforms.

    Every joint of the robot contributes one entry, keyed by
    ``generateKey(parent, child)``::

        {'child': child, 'parent': parent,
         'trans': <joint origin xyz>,
         'quat': list(quaternion_from_euler(roll, pitch, yaw, axes='sxyz'))}

    A joint without an ``<origin>`` element, or an origin missing ``xyz`` or
    ``rpy``, is treated as a zero translation / zero rotation.

    @param file: path to the URDF file
    @return: dictionary mapping 'parent-child' keys to the entries above
    """
    robot = URDF.from_xml_file(file)
    transforms = {}

    for joint in robot.joints:
        origin = joint.origin
        # <origin> is optional on URDF joints and the parser can leave it as
        # None; default the missing pieces to zero instead of crashing.
        xyz = origin.xyz if origin is not None else None
        rpy = origin.rpy if origin is not None else None
        if xyz is None:
            xyz = [0.0, 0.0, 0.0]
        if rpy is None:
            rpy = [0.0, 0.0, 0.0]

        transforms[generateKey(joint.parent, joint.child)] = {
            'child': joint.child,
            'parent': joint.parent,
            'trans': xyz,
            'quat': list(quaternion_from_euler(rpy[0], rpy[1], rpy[2], axes='sxyz')),
        }

    return transforms

def generateKey(parent, child, suffix=''):
    """Return the pool key for the transform from ``parent`` to ``child``.

    The key has the form ``'<parent>-<child><suffix>'``; ``suffix`` is
    appended verbatim and is empty by default.
    """
    return '-'.join((parent, child)) + suffix

def getChain(from_frame, to_frame, transform_pool, draw=False):
    """Get the chain of transforms linking two reference frames.

    Builds an undirected graph whose nodes are frames and whose edges are the
    parent/child pairs found in ``transform_pool``, then takes the shortest
    path between the two frames. The graph is undirected, so a hop in the
    chain may correspond to a transform that the pool stores in the reverse
    direction.

    @param from_frame: initial frame
    @param to_frame: final frame
    @param transform_pool: dictionary of transforms; every value must have
        'parent' and 'child' entries
    @param draw: if truthy, plot the transform graph with matplotlib (debug aid)
    @return: list of dictionaries, one per hop along the path:
        [{'parent': parent, 'child': child, 'key': 'parent-child'}, ...].
        Empty when from_frame == to_frame.
    @raise: errors from networkx.shortest_path propagate, e.g. when a frame is
        not in the pool or the two frames are not connected
    """
    graph = nx.Graph()  # undirected, so paths may traverse transforms backwards
    for transform in transform_pool.values():
        graph.add_edge(transform['parent'], transform['child'])

    if draw:  # debug only; matplotlib is imported lazily so it stays optional
        import matplotlib.pyplot as plt
        nx.draw(graph, with_labels=True)
        plt.show()

    path = nx.shortest_path(graph, from_frame, to_frame)
    # Consecutive frames along the path form the (parent, child) hops.
    return [{'parent': parent, 'child': child, 'key': generateKey(parent, child)}
            for parent, child in zip(path, path[1:])]


def getAggregateTransform(chain, transforms):
    """Multiply the local transforms along a chain into one global transform.

    @param chain: list of dictionaries with 'parent' and 'child' entries, in
        path order (as produced by getChain)
    @param transforms: pool of transforms keyed by generateKey(parent, child);
        each value has 'trans' and 'quat' entries
    @return: 4x4 homogeneous numpy array. It is the left-to-right product of
        the per-link transforms, or the identity for an empty chain.
    @raise ValueError: if neither parent->child nor child->parent is in the pool
    """
    # np.float was only an alias of the builtin float and was removed in
    # NumPy 1.24; np.float64 is the same dtype and works on every version.
    transform = np.eye(4, dtype=np.float64)

    for link in chain:
        key = generateKey(link['parent'], link['child'])
        inverse_key = generateKey(link['child'], link['parent'])
        if key in transforms:  # transform stored in the direction we need
            entry = transforms[key]
            parent_T_child = translationQuaternionToTransform(entry['trans'], entry['quat'])
        elif inverse_key in transforms:
            # Only child->parent is stored; invert it to obtain parent->child.
            entry = transforms[inverse_key]
            parent_T_child = np.linalg.inv(translationQuaternionToTransform(entry['trans'], entry['quat']))
        else:
            raise ValueError('Transform from ' + link['parent'] + ' to ' + link['child'] + ' does not exist.')

        # Right-multiplication accumulates the links in chain order.
        transform = np.dot(transform, parent_T_child)

    return transform


def getTransform(from_frame, to_frame, transforms, draw=False):
    """Return the 4x4 homogeneous transform between two arbitrary frames.

    The path between the frames is found with getChain, and the local
    transforms along it are then multiplied by getAggregateTransform.

    @param from_frame: starting frame
    @param to_frame: ending frame
    @param transforms: dictionary of transforms keyed 'parent-child'
    @param draw: forwarded to getChain (debug plot of the frame graph)
    @return: the global transformation (4x4 homogeneous)
    """
    return getAggregateTransform(getChain(from_frame, to_frame, transforms, draw), transforms)
