#!/usr/bin/env python3
"""
Casts an optimization problem using an ATOM dataset file as input. Then calibrates by running the optimization.
"""
import copy
import json
import os, math, signal, sys
import argparse, random
from functools import partial
import numpy as np
import ros_numpy

# 3rd-party
import rospy, tf
from colorama import Fore, Style, Back
from urdf_parser_py.urdf import URDF

# own packages
import OptimizationUtils.OptimizationUtils as OptimizationUtils
import atom_core.naming
from atom_calibration.calibration.getters_and_setters import getterTransform, setterTransform, getterCameraIntrinsics, \
    setterCameraIntrinsics

from atom_core.naming import generateKey, generateName
from atom_core.utilities import waitForKeyPress, waitForKeyPress2
from atom_calibration.calibration.objective_function import getPointsInSensorAsNPArray, objectiveFunction
import atom_calibration.calibration.patterns_config as patterns
from atom_calibration.calibration.visualization import setupVisualization, visualizationFunction
from atom_core.config_io import readXacroFile, uriReader
from atom_core.dataset_io import loadResultsJSON, saveResultsJSON, \
    getPointCloudMessageFromDictionary, filterCollectionsFromDataset, filterSensorsFromDataset, \
    addNoiseToInitialGuess


# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------
def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl+C): announce the interruption and terminate with exit status 0.

    :param sig: signal number delivered by the OS (unused).
    :param frame: stack frame that was interrupted (unused).
    :raises SystemExit: always, with exit code 0.
    """
    print('Stopping optimization (Ctrl+C pressed)')
    raise SystemExit(0)


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------
def main():
    """Set up and run the ATOM calibration optimization from a dataset json file.

    Steps, in order:
      1. Parse command line arguments (arguments starting with ``__``, as added by roslaunch, are discarded).
      2. Load the dataset, filter collections and sensors, and create the pattern labels.
      3. Store a copy of the initial transforms and add noise to the initial guess (as configured by the arguments).
      4. Create the optimizer, its data models, its parameters (sensor transforms, optional camera intrinsics and
         pattern pose) and its residuals.
      5. Compute a per message type residual normalizer, optionally profile the objective function, and run the
         optimization.
      6. Print the final errors and save the resulting dataset json and the updated xacro.

    Raises:
        ValueError: if the initial estimate of the pattern pose is missing or inconsistent with the calibration
            configuration, or if a calibrated transform has no matching joint in the robot description.
    """
    # ---------------------------------------
    # --- Parse command line argument
    # ---------------------------------------
    signal.signal(signal.SIGINT, signal_handler)
    # print('Press Ctrl+C')
    # signal.pause()

    ap = argparse.ArgumentParser()
    ap = OptimizationUtils.addArguments(ap)  # OptimizationUtils arguments
    ap.add_argument("-json", "--json_file", help="Json file containing input dataset.", type=str, required=True)
    ap.add_argument("-v", "--verbose", help="Be verbose", action='store_true', default=False)
    ap.add_argument("-rv", "--ros_visualization", help="Publish ros visualization markers.", action='store_true')
    ap.add_argument("-si", "--show_images", help="shows images for each camera", action='store_true', default=False)
    ap.add_argument("-oi", "--optimize_intrinsics", help="Adds camera instrinsics to the optimization",
                    action='store_true', default=False)
    ap.add_argument("-pof", "--profile_objective_function",
                    help="Runs and prints a profile of the objective function, then exits.",
                    action='store_true', default=False)
    ap.add_argument("-sr", "--sample_residuals", help="Samples residuals", type=float, default=1)
    ap.add_argument("-ss", "--sample_seed", help="Sampling seed", type=int)
    ap.add_argument("-ajf", "--all_joints_fixed",
                    help="Assume all joints are fixed and because of that draw a single robot mesh. Overrides "
                         "automatic detection of static robot.",
                    action='store_true', default=False)
    ap.add_argument("-uic", "--use_incomplete_collections",
                    help="Remove any collection which does not have a detection for all sensors.",
                    action='store_true', default=False)
    ap.add_argument("-rpd", "--remove_partial_detections",
                    help="Remove detected labels which are only partial. Used or the Charuco.",
                    action='store_true', default=False)
    ap.add_argument("-nig", "--noisy_initial_guess", nargs=2, metavar=('translation', 'rotation'),
                    help="Percentage of noise to add to the initial guess atomic transformations set before.",
                    type=float, default=[0.0, 0.0])
    # NOTE(security): the two selection functions below are built with eval() on the raw command line string, so
    # they execute arbitrary python code. Only use them with trusted input.
    ap.add_argument("-ssf", "--sensor_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help='A string to be evaluated into a lambda function that receives a sensor name as input and '
                         'returns True or False to indicate if the sensor should be loaded (and used in the '
                         'optimization). The Syntax is lambda name: f(x), where f(x) is the function in python '
                         'language. Example: lambda name: name in ["left_laser", "frontal_camera"] , to load only '
                         'sensors left_laser and frontal_camera')
    ap.add_argument("-csf", "--collection_selection_function", default=None, type=lambda s: eval(s, globals()),
                    help='A string to be evaluated into a lambda function that receives a collection name as input and '
                         'returns True or False to indicate if the collection should be loaded (and used in the '
                         'optimization). The Syntax is lambda name: f(x), where f(x) is the function in python '
                         'language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.')
    ap.add_argument("-phased", "--phased_execution", help="Stay in a loop before calling optimization, and in another "
                                                          "after calling the optimization. Good for debugging.",
                    action='store_true', default=False)
    ap.add_argument("-ipg", "--initial_pose_ghost", help="Draw a ghost mesh with the systems initial pose. Good for debugging.",
                    action='store_true', default=False)
    ap.add_argument('-oj', '--output_json', help='Full path to output json file.', type=str, required=False, default=None)
    ap.add_argument('-ox', '--output_xacro', help='Full path to output xacro file.', type=str, required=False, default=None)

    # Roslaunch adds two arguments (__name and __log) that break our parser. Lets remove those.
    arglist = [x for x in sys.argv[1:] if not x.startswith('__')]
    args = vars(ap.parse_args(args=arglist))

    # ---------------------------------------
    # --- INITIALIZATION Read data from file
    # ---------------------------------------
    # Loads a json file containing the detections. Returned json_file has path resolved by urireader.
    dataset, json_file = loadResultsJSON(args['json_file'], args['collection_selection_function'])

    # ---------------------------------------
    # --- Filter some collections and / or sensors from the dataset
    # ---------------------------------------
    dataset = filterCollectionsFromDataset(dataset, args)  # filter collections

    # Creating the pattern labels must be done before deleting the sensors to cope with the possibility of
    # setting up an optimization without cameras. For now we MUST have a camera to estimate the initial parameters
    # related to the pattern pose (we use solve PNP for a camera).
    dataset['patterns'] = patterns.createPatternLabels(args, dataset)  # TODO: Solve this strange dependency.

    dataset = filterSensorsFromDataset(dataset, args)  # filter sensors

    print('Loaded dataset containing ' + str(len(dataset['sensors'].keys())) + ' sensors and ' + str(
        len(dataset['collections'].keys())) + ' collections.')

    # ---------------------------------------
    # --- Store initial values for transformations to be optimized
    # ---------------------------------------
    # Each collection gets a deep copy of its transforms, stored under a key with the 'ini' suffix and with the
    # parent / child link names also carrying the 'ini' suffix.
    for collection_key, collection in dataset['collections'].items():
        initial_transform_key = generateName('transforms', suffix='ini')
        collection[initial_transform_key] = copy.deepcopy(collection['transforms'])
        for transform_key, transform in collection[initial_transform_key].items():
            transform['parent'] = generateName(transform['parent'], suffix='ini')
            transform['child'] = generateName(transform['child'], suffix='ini')

    # ---------------------------------------
    # --- Define selected collection key.
    # ---------------------------------------
    # For the getters we only need to get one collection. Lets take the first key on the dictionary and always get that
    # transformation.
    selected_collection_key = list(dataset['collections'].keys())[0]
    print('Selected collection key is ' + str(selected_collection_key))

    # ---------------------------------------
    # --- Add noise to the initial guess atomic transformations to be calibrated.
    # ---------------------------------------
    addNoiseToInitialGuess(dataset, args, selected_collection_key)

    # ---------------------------------------
    # --- DETECT EDGES IN THE LASER SCANS
    # ---------------------------------------
    # TODO this should go to the collect and label
    for sensor_key, sensor in dataset['sensors'].items():
        if sensor['msg_type'] == 'LaserScan':  # only for lasers
            for collection_key, collection in dataset['collections'].items():
                idxs = collection['labels'][sensor_key]['idxs']
                edges = []  # a list of edges
                # An edge exists wherever two consecutive labelled indexes are not contiguous. Both sides of the
                # gap are stored (as positions in idxs).
                for i in range(0, len(idxs) - 1):
                    if (idxs[i + 1] - idxs[i]) != 1:
                        edges.append(i)
                        edges.append(i + 1)

                # Remove first (right most) and last (left most) edges, since these are often false edges.
                if len(edges) > 0:
                    edges.pop(0)  # remove the first element.
                if len(edges) > 0:
                    edges.pop()  # if the index is not given, then the last element is popped out and removed.
                collection['labels'][sensor_key]['edge_idxs'] = edges

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Create data models
    # ---------------------------------------
    opt = OptimizationUtils.Optimizer()
    opt.addDataModel('args', args)
    opt.addDataModel('dataset', dataset)

    # ---------------------------------------
    # --- DEFINE THE VISUALIZATION FUNCTION
    # ---------------------------------------
    opt.setInternalVisualization(bool(args['view_optimization']))

    if args['ros_visualization']:
        print("Configuring visualization ... ")
        graphics = setupVisualization(dataset, args, selected_collection_key)
        opt.addDataModel('graphics', graphics)

        opt.setVisualizationFunction(visualizationFunction, args['ros_visualization'], niterations=1, figures=[])

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add sensor parameters
    # ---------------------------------------
    # Each sensor will have a position (tx,ty,tz) and a rotation (r1,r2,r3)

    # Add parameters related to the sensors
    # TODO temporary placement of top_left_camera
    # dataset['calibration_config']['anchored_sensor'] = 'world_camera'
    print('Anchored sensor is ' + Fore.GREEN + dataset['calibration_config'][
        'anchored_sensor'] + Style.RESET_ALL)

    anchored_sensor = dataset['calibration_config']['anchored_sensor']
    if anchored_sensor in dataset['sensors']:
        anchored_parent = dataset['sensors'][anchored_sensor]['calibration_parent']
        anchored_child = dataset['sensors'][anchored_sensor]['calibration_child']
        anchored_transform_key = atom_core.naming.generateKey(anchored_parent, anchored_child)
    else:
        anchored_transform_key = ''  # no transform is anchored
    # TODO If we want an anchored sensor we should search (and fix) all the transforms in its chain that are
    #  being optimized

    print('Creating parameters ...')
    # Stemming from the config json, we define a transform to be optimized for each sensor. It could happen that two
    # or more sensors define the same transform to be optimized (#120). To cope with this we first create a list of
    # transformations to be optimized and then compute the unique set of that list.
    transforms_set = set()
    for sensor_key, sensor in dataset['sensors'].items():
        transform_key = atom_core.naming.generateKey(sensor['calibration_parent'], sensor['calibration_child'])
        transforms_set.add(transform_key)

    for transform_key in transforms_set:  # push six parameters for each transform to be optimized.
        initial_transform = getterTransform(dataset, transform_key=transform_key,
                                            collection_name=selected_collection_key)

        if transform_key == anchored_transform_key:
            # The anchored transform is kept (practically) constant by bounding it to an epsilon wide interval
            # around its initial value.
            bound_max = [x + sys.float_info.epsilon for x in initial_transform]
            bound_min = [x - sys.float_info.epsilon for x in initial_transform]
        else:
            bound_max = [np.inf for _ in initial_transform]
            bound_min = [-np.inf for _ in initial_transform]

        # The getter reads from the selected collection, the setter is called with collection_name=None.
        opt.pushParamVector(group_name=transform_key, data_key='dataset',
                            getter=partial(getterTransform, transform_key=transform_key,
                                           collection_name=selected_collection_key),
                            setter=partial(setterTransform, transform_key=transform_key,
                                           collection_name=None),
                            suffix=['_x', '_y', '_z', '_r1', '_r2', '_r3'], bound_max=bound_max, bound_min=bound_min)

    # Intrinsics
    # TODO bound_min and max for intrinsics
    if args['optimize_intrinsics']:
        for sensor_key, sensor in dataset['sensors'].items():
            if sensor['msg_type'] == 'Image':  # if sensor is a camera add intrinsics
                opt.pushParamVector(group_name=str(sensor_key) + '_intrinsics', data_key='dataset',
                                    getter=partial(getterCameraIntrinsics, sensor_key=sensor_key),
                                    setter=partial(setterCameraIntrinsics, sensor_key=sensor_key),
                                    suffix=['_fx', '_fy', '_cx', '_cy', '_k1', '_k2', '_t1', '_t2', '_k3'])

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add pattern(s) parameters
    # ---------------------------------------
    # Each Pattern will have the position (tx,ty,tz) and rotation (r1,r2,r3)

    if not dataset['calibration_config']['calibration_pattern']['fixed']:  # Pattern not fixed -------------------------
        # If pattern is not fixed there will be a transform for each collection. To tackle this reference link called
        # according to what is on the dataset['calibration_config']['calibration_pattern']['link'] is prepended with
        # a "c<collection_name>" appendix. This is done automatically for the collection['transforms'] when
        # publishing ROS, but we must add this to the parameter name.
        parent = dataset['calibration_config']['calibration_pattern']['parent_link']
        child = dataset['calibration_config']['calibration_pattern']['link']
        transform_key = atom_core.naming.generateKey(parent, child)

        for collection_key, collection in dataset['collections'].items():  # iterate all collections

            # Set transform using the initial estimate of the transformations.
            initial_estimate = dataset['patterns']['transforms_initial'][collection_key]
            if not initial_estimate['detected'] or not parent == initial_estimate['parent'] or \
                    not child == initial_estimate['child']:
                raise ValueError('Cannot set initial estimate for pattern at collection ' + collection_key)

            # TODO remove this test only used for mmtbot in simulation, dataset test2
            # initial_estimate['trans'] = [-0.44961655139923096,1.479259967803955,1.1446490287780762]
            # initial_estimate['quat'] =[0.735267162322998,0.0,0.0,0.6777775287628174]

            collection['transforms'][transform_key] = {'parent': parent, 'child': child,
                                                       'trans': initial_estimate['trans'],
                                                       'quat': initial_estimate['quat']}

            # Finally push the six parameters to describe the patterns pose w.r.t its parent link:
            #   a) The Getter will pick up transform from the collection collection_key
            #   b) The Setter will receive a transform value and a collection_key and copy the transform to that of
            #   the corresponding collection
            # No explicit bounds are passed here, so the defaults of pushParamVector apply.
            opt.pushParamVector(group_name='c' + collection_key + '_' + transform_key, data_key='dataset',
                                getter=partial(getterTransform, transform_key=transform_key,
                                               collection_name=collection_key),
                                setter=partial(setterTransform, transform_key=transform_key,
                                               collection_name=collection_key),
                                suffix=['_x', '_y', '_z', '_r1', '_r2', '_r3'])

    else:  # fixed pattern ---------------------------------------------------------------------------------------------
        # if pattern is fixed it will not be replicated for all collections , i.e. there will be a single
        # reference link called according to what is on the dataset['calibration_config']['calibration_pattern'][
        # 'link']
        parent = dataset['calibration_config']['calibration_pattern']['parent_link']
        child = dataset['calibration_config']['calibration_pattern']['link']
        transform_key = atom_core.naming.generateKey(parent, child)

        # Set transform using the initial estimate of the transformations.
        initial_estimate = dataset['patterns']['transforms_initial'][selected_collection_key]
        if not initial_estimate['detected'] or not parent == initial_estimate['parent'] or \
                not child == initial_estimate['child']:
            # The estimate under test belongs to the selected collection, so that is the one to report.
            raise ValueError('Cannot set initial estimate for pattern at collection ' + str(selected_collection_key))

        # The pattern is fixed but we have a replicated transform for each collection. Lets add those.
        for collection_key, collection in dataset['collections'].items():
            collection['transforms'][transform_key] = {'parent': parent, 'child': child,
                                                       'trans': initial_estimate['trans'],
                                                       'quat': initial_estimate['quat']}

        # Finally push the six parameters to describe the patterns pose w.r.t its parent link:
        #   a) The Getter will pick up the collection from one selected collection (it does not really matter which,
        #       since they are replicas);
        #   b) The Setter will receive a transform value and copy that to all collection replicas, to ensure they
        #       all have the same value. This is done by setting  "collection_name=None".
        opt.pushParamVector(group_name=transform_key, data_key='dataset',
                            getter=partial(getterTransform, transform_key=transform_key,
                                           collection_name=selected_collection_key),
                            setter=partial(setterTransform, transform_key=transform_key,
                                           collection_name=None),
                            suffix=['_x', '_y', '_z', '_r1', '_r2', '_r3'])

    # opt.printParameters()

    # ---------------------------------------
    # --- Define THE OBJECTIVE FUNCTION
    # ---------------------------------------
    opt.setObjectiveFunction(objectiveFunction)

    # ---------------------------------------
    # --- Define THE RESIDUALS
    # ---------------------------------------
    # Each residual is computed after the sensor and the pattern of a collection. Thus, each error will be affected
    # by the parameters tx,ty,tz,r1,r2,r3 of the sensor and the pattern

    print("Creating residuals ... ")

    # Use the user provided seed when given, so that the residual sampling can be reproduced. Otherwise draw one and
    # print it.
    if args['sample_seed'] is None:
        seed = random.randrange(sys.maxsize)
    else:
        seed = args['sample_seed']

    rng = random.Random(seed)
    print('  RNG Seed: ' + str(seed))

    for collection_key, collection in dataset['collections'].items():
        for sensor_key, sensor in dataset['sensors'].items():
            if not collection['labels'][sensor_key]['detected']:  # if pattern not detected by sensor in collection
                continue

            # Sensor related parameters
            sensors_transform_key = atom_core.naming.generateKey(sensor['calibration_parent'],
                                                                 sensor['calibration_child'])
            params = opt.getParamsContainingPattern(sensors_transform_key)

            # Intrinsics parameters
            if sensor['msg_type'] == 'Image' and args['optimize_intrinsics']:
                params.extend(opt.getParamsContainingPattern(sensor_key + '_intrinsics'))

            # Pattern related parameters
            if dataset['calibration_config']['calibration_pattern']['fixed']:
                pattern_transform_key = atom_core.naming.generateKey(
                    dataset['calibration_config']['calibration_pattern']['parent_link'],
                    dataset['calibration_config']['calibration_pattern']['link'])
            else:
                pattern_transform_key = 'c' + collection_key + '_' + atom_core.naming.generateKey(
                    dataset['calibration_config']['calibration_pattern']['parent_link'],
                    dataset['calibration_config']['calibration_pattern']['link'])

            params.extend(opt.getParamsContainingPattern(pattern_transform_key))  # pattern related params

            if sensor['msg_type'] == 'Image':  # if sensor is a camera add one residual per labelled corner

                for idx in collection['labels'][sensor_key]['idxs']:  # using all pattern corners
                    rname = 'c' + str(collection_key) + '_' + str(sensor_key) + '_corner' + str(idx['id'])
                    opt.pushResidual(name=rname, params=params)

            elif sensor['msg_type'] == 'LaserScan':  # if sensor is a 2D lidar add extrema, inner and beam residuals
                # TODO Implement sampling here if relevant
                # Extrema points (longitudinal error)
                opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_eleft', params=params)
                opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_eright', params=params)

                # Inner points, use detection of edges (longitudinal error)
                for idx, _ in enumerate(collection['labels'][sensor_key]['edge_idxs']):
                    opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_inner_' + str(idx), params=params)

                # Laser beam (orthogonal error)
                for idx in range(0, len(collection['labels'][sensor_key]['idxs'])):
                    opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_beam_' + str(idx), params=params)

            elif sensor['msg_type'] == 'PointCloud2':  # if sensor is a 3D lidar add two types of residuals

                # Laser beam error (or orthogonal error) ==

                # The number of residuals for this error can be huge.
                # Therefore, we do a random sample to reduce the number.
                # The sample percentage is clamped to the interval [0.1, 1.0]
                population = range(0, len(collection['labels'][sensor_key]['idxs']))
                nsamples = int(max(0.1, min(args['sample_residuals'], 1.0)) * len(population))
                samples = rng.sample(population, nsamples)

                # Save it to be used in the objective function.
                collection['labels'][sensor_key]['samples'] = samples

                for idx in samples:
                    opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_oe_' + str(idx), params=params)

                # Extrema displacement error
                for idx in range(0, len(collection['labels'][sensor_key]['idxs_limit_points'])):
                    opt.pushResidual(name='c' + collection_key + '_' + sensor_key + '_ld_' + str(idx), params=params)

            # print('Adding residuals for sensor ' + sensor_key + ' with msg_type ' + sensor['msg_type'] +
            #       ' affected by parameters:\n' + str(params))

    # opt.printParameters()
    # opt.printResiduals()
    # exit(0)

    # ---------------------------------------
    # --- Compute the SPARSE MATRIX
    # ---------------------------------------
    print("Computing sparse matrix ... ")
    opt.computeSparseMatrix()
    # opt.printSparseMatrix()

    # ---------------------------------------
    # --- Get a normalizer for each residual type
    # ---------------------------------------
    # Start with a neutral normalizer (1.0) for every message type, evaluate the objective function once with it, and
    # then derive the actual normalizers from the residuals obtained.
    msg_types = {s['msg_type'] for s in dataset['sensors'].values()}
    normalizer = {k: 1.0 for k in msg_types}
    opt.addDataModel('normalizer', normalizer)

    residuals = objectiveFunction(opt.data_models)
    opt.callObjectiveFunction()
    for msg_type in msg_types:
        values = []
        for sensor_key, sensor in dataset['sensors'].items():
            if sensor['msg_type'] == msg_type:
                # NOTE(review): residuals are matched by substring, so a sensor whose name is contained in the name
                # of another sensor will also pick up the residuals of that other sensor - confirm this is intended.
                values += [residuals[k] for k in residuals.keys() if sensor_key in k]

        if not values:  # no residuals for this message type, nothing to normalize
            print(Fore.YELLOW + 'No residuals for msg_type ' + str(msg_type) + ', keeping normalizer 1.0' +
                  Style.RESET_ALL)
            continue

        # If the normalizer is too big, the residuals may lose their influence when close to zero.
        candidate = np.abs(np.log(np.mean(values)))

        # A nan, infinite or zero normalizer (mean of residuals not positive, or exactly 1) carries no usable scale.
        # In that case keep the neutral value.
        if np.isfinite(candidate) and candidate > 0:
            normalizer[msg_type] = candidate
        else:
            print(Fore.YELLOW + 'Could not compute normalizer for msg_type ' + str(msg_type) +
                  ', keeping normalizer 1.0' + Style.RESET_ALL)

    # ---------------------------------------
    # --- Profile Objective Function
    # ---------------------------------------
    if args['profile_objective_function']:
        # Imported here so that line_profiler is only required when profiling is requested.
        from line_profiler import LineProfiler
        lp1 = LineProfiler()
        lp_wrapper1 = lp1(objectiveFunction)
        lp_wrapper1(opt.data_models)

        print(Back.LIGHTYELLOW_EX + '\n\n' + Fore.RED + Style.BRIGHT + 'First call of Objective Function (no cache)'
              + '\n\n' + Back.RESET + Fore.RESET + Style.RESET_ALL)
        lp1.print_stats()

        lp2 = LineProfiler()
        lp_wrapper2 = lp2(objectiveFunction)
        lp_wrapper2(opt.data_models)
        print(Back.LIGHTYELLOW_EX + '\n\n' + Fore.RED + Style.BRIGHT + 'Second call of Objective Function (with cache)'
              + '\n\n' + Back.RESET + Fore.RESET + Style.RESET_ALL)
        lp2.print_stats()
        sys.exit(0)  # sys.exit rather than the builtin exit(), which only exists when the site module is loaded

    # ---------------------------------------
    # --- Start Optimization
    # ---------------------------------------
    if args['phased_execution']:
        waitForKeyPress2(function=opt.callObjectiveFunction, timeout=0.1)

    print('Initializing optimization ...')
    # options = {'ftol': 1e-3, 'xtol': 1e-3, 'gtol': 1e-3, 'diff_step': None, 'x_scale': 'jac'}
    # options = {'ftol': 1e-8, 'xtol': 1e-8, 'gtol': 1e-8, 'diff_step': None, 'x_scale': 'jac'}
    options = {'ftol': 1e-6, 'xtol': 1e-6, 'gtol': 1e-6, 'diff_step': None, 'x_scale': 'jac'}
    # options = {'ftol': 1e-5, 'xtol': 1e-5, 'gtol': 1e-5, 'diff_step': None, 'x_scale': 'jac'}
    opt.startOptimization(optimization_options=options)

    if args['phased_execution']:
        waitForKeyPress2(opt.callObjectiveFunction, timeout=0.1, message=Fore.RED + 'Optimization finished' +
                                                                        Style.RESET_ALL)
    # Error Report
    # The residuals are multiplied by the normalizer of the corresponding message type before averaging.
    print('\nFinal errors:')
    r = objectiveFunction(opt.data_models)
    for sensor_key, sensor in dataset['sensors'].items():
        keys = [k for k in r.keys() if sensor_key in k]
        v = [r[k] * normalizer[sensor['msg_type']] for k in keys]
        print("Sensor " + sensor_key + " " + str(np.mean(v)))

    # ---------------------------------------
    # --- Save updated JSON file
    # ---------------------------------------

    if args['output_json'] is None:
        # Default: next to the input dataset. os.path.join avoids writing to the filesystem root when the input path
        # has no directory component.
        filename_results_json = os.path.join(os.path.dirname(json_file), 'atom_calibration.json')
    else:
        filename_results_json = args['output_json']

    saveResultsJSON(filename_results_json, dataset)

    # ---------------------------------------
    # --- Save updated xacro
    # ---------------------------------------
    # Cycle all sensors in calibration config, and for each replace the optimized transform in the original xacro
    # Parse xacro description file
    description_file, _, _ = uriReader(dataset['calibration_config']['description_file'])
    rospy.loginfo('Reading description file ' + description_file + '...')
    xml_robot = readXacroFile(description_file)

    for sensor_key in dataset['calibration_config']['sensors']:
        child = dataset['calibration_config']['sensors'][sensor_key]['child_link']
        parent = dataset['calibration_config']['sensors'][sensor_key]['parent_link']
        transform_key = atom_core.naming.generateKey(parent, child)

        # The optimized values are read from the selected collection.
        trans = list(dataset['collections'][selected_collection_key]['transforms'][transform_key]['trans'])
        quat = list(dataset['collections'][selected_collection_key]['transforms'][transform_key]['quat'])
        found = False

        for joint in xml_robot.joints:
            if joint.parent == parent and joint.child == child:
                found = True
                print('Found joint: ' + str(joint.name))

                print('Replacing xyz = ' + str(joint.origin.xyz) + ' by ' + str(trans))
                joint.origin.xyz = trans

                rpy = list(tf.transformations.euler_from_quaternion(quat, axes='sxyz'))
                print('Replacing rpy = ' + str(joint.origin.rpy) + ' by ' + str(rpy))
                joint.origin.rpy = rpy
                break

        if not found:
            raise ValueError('Could not find transform ' + str(transform_key) + ' in ' + description_file)

    if args['output_xacro'] is None:
        filename_results_xacro = '/tmp/optimized.urdf.xacro'
    else:
        filename_results_xacro = args['output_xacro']

    with open(filename_results_xacro, 'w') as out:
        out.write(URDF.to_xml_string(xml_robot))

    print('Optimized xacro file saved to ' + str(filename_results_xacro) + ' . You can use it as a ROS robot_description.')


# Run the calibration only when this file is executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()
