#!/usr/bin/env python3
"""
Casts an optimization problem using an ATOM dataset file as input. Then calibrates by running the optimization.
"""

# Standard imports
import argparse
import copy
import os
import random
import signal
import sys
from functools import partial
from copy import deepcopy
import types

import numpy as np

# Atom imports
from colorama import Fore, Style
import yaml
from atom_core.optimization_utils import Optimizer, addArguments
from atom_calibration.calibration.getters_and_setters import (
    getterCameraIntrinsics, getterTransform, setterCameraIntrinsics, setterTransform,
    getterJointParam, setterJointParam)
from atom_calibration.calibration.objective_function import errorReport, objectiveFunction, replaceTransformsFromJoints
from atom_calibration.calibration.visualization import setupVisualization, visualizationFunction
from atom_core.dataset_io import (addNoiseToInitialGuess,
                                  addBiasToJointParameters,
                                  addNoiseToTF,
                                  addNoiseFromNoisyTFLinks,
                                  checkIfAtLeastOneLabeledCollectionPerSensor,
                                  filterCollectionsFromDataset,
                                  filterSensorsFromDataset,
                                  loadResultsJSON,
                                  saveAtomDataset,
                                  filterJointsFromDataset,
                                  filterJointParametersFromDataset,
                                  filterAdditionalTfsFromDataset,
                                  filterPatternsFromDataset)
from atom_core.naming import generateName, generateKey
from atom_core.utilities import (atomError, createLambdaExpressionsForArgs,
                                 waitForKeyPress2,
                                 atomStartupPrint,
                                 verifyAnchoredSensor,
                                 printComparisonToGroundTruth, saveCommandLineArgsYml)
from atom_core.xacro_io import saveResultsXacro
from atom_core.results_yml_io import saveResultsYml
from atom_core.config_io import parse_list_of_transformations, mutually_inclusive_conditions


# -------------------------------------------------------------------------------
# --- FUNCTIONS
# -------------------------------------------------------------------------------
def signal_handler(sig, frame):
    """Announce that the optimization is being stopped and terminate the process.

    Intended to be installed as a handler for SIGINT (Ctrl+C).

    Args:
        sig: Signal number passed by the signal machinery (unused).
        frame: Current stack frame passed by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    farewell = "Stopping optimization (Ctrl+C pressed)"
    print(farewell)
    # Equivalent to sys.exit(0): unwinds the interpreter with a success status.
    raise SystemExit(0)


# -------------------------------------------------------------------------------
# --- MAIN
# -------------------------------------------------------------------------------
def main():

    atomStartupPrint('Starting calibration')

    # ---------------------------------------
    # --- Parse command line arguments
    # ---------------------------------------
    signal.signal(signal.SIGINT, signal_handler)

    ap = argparse.ArgumentParser()
    ap = addArguments(ap)
    ap.add_argument("-json", "--json_file", type=str, required=True,
                    help="Json file containing input dataset.", )
    ap.add_argument("-v", "--verbose", help="Be verbose",
                    action="store_true", default=False)
    ap.add_argument("-snv", "--show_normalized_values", action="store_true", default=False,
                    help="In the output table, shows normalized residuals alongside the original values.")
    ap.add_argument(
        "-rnm", "--rgb_normalizer_multiplier", type=float, default=1.0,
        help="Multiplier of rgb normalizer. Used to enhance the importance of rgb modality errors vs metric errors. Smaller values will make rgb errors more important.")
    ap.add_argument("-rv", "--ros_visualization", action="store_true",
                    help="Publish ros visualization markers.")
    ap.add_argument("-dpcc", "--draw_per_collection_colors", action="store_true",
                    help="Publish ros visualization markers with one color per collection.")
    ap.add_argument("-da", "--draw_alpha", type=float, default=0.5,
                    help="Publish ros visualization markers with alpha for better visualization.")
    ap.add_argument("-si", "--show_images", action="store_true",
                    default=False, help="shows images for each camera")
    ap.add_argument("-oi", "--optimize_intrinsics", action="store_true", default=False,
                    help="Adds camera intrinsics to the  optimization",)
    ap.add_argument("-sr", "--sample_residuals",
                    help="Samples residuals", type=float, default=1)
    ap.add_argument("-ss", "--sample_seed", help="Sampling seed", type=int)
    ap.add_argument("-slr", "--sample_longitudinal_residuals",
                    help="Samples residuals", type=float, default=1)
    ap.add_argument("-ajf", "--all_joints_fixed", action="store_true", default=False,
                    help="Assume all joints are fixed and because of that draw a single robot mesh."
                    "Overrides automatic detection of static robot.")
    ap.add_argument("-oas", "--only_anchored_sensor", action="store_true", default=False,
                    help="Runs optimization only using the anchored sensor and discarding all others.")
    ap.add_argument(
        "-as", "--anchor_sensors", nargs='+',
        help='Anchor sensors. You can name as many sensors as needed. Sensors names should be the ones defined in the config.yml. Example: -as rgb_hand lidar_center',
        type=str, required=False)
    ap.add_argument("-ap", "--anchor_patterns", action="store_true", default=False,
                    help="Runs optimization without changing the poses of the patterns.")
    ap.add_argument(
        "-gtpp", "--ground_truth_pattern_poses", action="store_true", default=False,
        help="Sets pattern poses from initial transforms stored in the dataset. Only for simulation, these are the values before eventually adding noise, which in the case of simulation consist of the ground truth.")
    ap.add_argument("-uic", "--use_incomplete_collections", action="store_true", default=False,
                    help="Remove any collection which does not have a detection for all sensors.", )
    ap.add_argument("-ias", "--ignore_anchored_sensor", action="store_true", default=False,
                    help="Ignore the anchored sensor information in the dataset.", )
    ap.add_argument(
        "-rpd", "--remove_partial_detections",
        help="Remove detected labels which are only partial."
        "Used or the Charuco.", action="store_true", default=False)
    ap.add_argument(
        "-nig", "--noisy_initial_guess", nargs=2, metavar=("translation", "rotation"),
        help="Magnitude of noise to add to the initial guess atomic transformations set before starting optimization [meters, radians].",
        type=float, default=[0.0, 0.0],),
    ap.add_argument("-jbn", "--joint_bias_names", nargs='+',
                    help='Joints to add bias to', type=str, required=False)
    ap.add_argument(
        "-jbp", "--joint_bias_params", nargs='+',
        help='Joint parameters to add bias to. One of [origin_x, origin_y, origin_z, origin_roll, origin_pitch, origin_yaw].',
        type=str, required=False)
    ap.add_argument("-jbv", "--joint_bias_values", nargs='+',
                    help='Operates in tandem with "joint_bias_names"', type=float, required=False)
    
    ap.add_argument("-ntfl", "--noisy_tf_links", type=parse_list_of_transformations, help='''Pairs of links defining the transformations to which noise will be added, in the format linkParent1:linkChild2,linkParentA:linkChildB (links defining a transformation separated by
                    : and different pairs separeted by ,)
                    Note : This flag overrides the -nig flag''')
    ap.add_argument("-ntfv", "--noisy_tf_values", nargs=2, type=float, metavar=("translation", "rotation"),
                    help="Translation(m) and Rotation(rad)(respectively) noise values to add to the transformations specified by-ntfl")

    ap.add_argument(
        "-ssf", "--sensor_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a sensor name as input and "
        "returns True or False to indicate if the sensor should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["left_laser", "frontal_camera"] , to load only '
        "sensors left_laser and frontal_camera")
    ap.add_argument(
        "-csf", "--collection_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a collection name as input and "
        "returns True or False to indicate if the collection should be loaded (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        "language. Example: lambda name: int(name) > 5 , to load only collections 6, 7, and onward.")
    ap.add_argument(
        "-jsf", "--joint_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a joint name as input and "
        "returns True or False to indicate if the joint should be calibrated (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["left_arm_roll", "right_shoulder_lift"] , to load only '
        "joints left_arm_roll and right_shoulder_lift")
    ap.add_argument(
        "-jpsf", "--joint_parameter_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a joint parameter name as input and returns True or False to indicate if the joint parameter should be calibrated (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["origin_x", "origin_roll"] , to use only these parameters')
    ap.add_argument(
        "-psf", "--pattern_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives a pattern name as input (as defined in the config.yml) and returns True or False to indicate if the pattern should be used in the optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        "language. Example: lambda name: name in ['pattern_1', 'charuco_200x300'] , to load only these two patterns")
    ap.add_argument(
        "-atsf", "--additional_tf_selection_function", default=None, type=str,
        help="A string to be evaluated into a lambda function that receives an additional_tf name as input and returns True or False to indicate if the additional_tf should be calibrated (and used in the "
        "optimization). The Syntax is lambda name: f(x), where f(x) is the function in python "
        'language. Example: lambda name: name in ["base_link_to_base_link_mb", "left_wheel_to_base_link"] , to load only '
        "additional_tfs base_link_to_base_link_mb and left_wheel_to_base_link")
    ap.add_argument(
        "-phased", "--phased_execution",
        help="Stay in a loop before calling optimization, and in another "
        "after calling the optimization. Good for debugging.", action="store_true", default=False)
    ap.add_argument("-ipg", "--initial_pose_ghost", action="store_true", default=False,
                    help="Draw a ghost mesh with the systems initial pose. Good for debugging.")
    ap.add_argument(
        "-ctgt", "--comparison_to_ground_truth",
        help="Will print a comparison to ground truth at the end of the optimization. Warning: only makes sense if the values in the dataset are the ground truth, which is the case in simulated systems.",
        action="store_true", default=False)
    ap.add_argument("-sce", "--save_calibration_errors", action="store_true", default=False,
                    help="Saves a csv file with the errors at every iteration.")
    ap.add_argument("-ftol", "--optimization_ftol", help="ftol parameter for optimization.",
                    type=float, required=False, default=1e-5)
    ap.add_argument("-xtol", "--optimization_xtol", help="xtol parameter for optimization.",
                    type=float, required=False, default=1e-5)
    ap.add_argument("-gtol", "--optimization_gtol", help="gtol parameter for optimization.",
                    type=float, required=False, default=1e-5)
    ap.add_argument(
        "-diff_step", "--optimization_diff_step", help="diff_step parameter for optimization.",
        type=float, required=False, default=None)
    ap.add_argument(
        "-max_nfev", "--optimization_max_nfev", help="max_nfev parameter for optimization.",
        type=int, required=False, default=None)
    ap.add_argument("-sfr", "--save_file_results", help="Store the results",
                    action='store_true', default=False)
    ap.add_argument(
        "-sfrn", "--save_file_results_name",
        help="Name of csv file to save the results."
        "Default: -test_json/results/{name_of_dataset}_{sensor_source}_to_{sensor_target}_results.csv",
        type=str, required=False)

    # Roslaunch adds two arguments (__name and __log) that break our parser. Lets remove those.
    arglist = [x for x in sys.argv[1:] if not x.startswith("__")]
    # these args have the selection functions as strings
    args_original = vars(ap.parse_args(args=arglist))
    args = createLambdaExpressionsForArgs(args_original)  # selection functions are now lambdas

    # ---------------------------------------
    # --- Setup some variables
    # ---------------------------------------
    args['dataset_folder'] = os.path.dirname(args['json_file'])
    output_folder = os.path.dirname(args['json_file'])

    # -sce wasn't working because args previously had this key
    args['output_folder'] = output_folder

    # ---------------------------------------
    # --- Read data from file
    # ---------------------------------------

    # Loads a json file containing the detections. Returned json_file has path resolved by urireader.
    dataset, _ = loadResultsJSON(args["json_file"], args["collection_selection_function"])

    if float(dataset['_metadata']['version']) < 3.0:
        atomError('Your dataset not version 3.0. Before running a calibration, you need to update it first with:\nrosrun atom_calibration update_dataset_from_version_2_to_3.')

    # make a copy before filtering out collections and sensors
    dataset_original = copy.deepcopy(dataset)

    # ---------------------------------------
    # --- Filter some collections, sensors, joints and additional tfs from the dataset
    # ---------------------------------------
    dataset = filterCollectionsFromDataset(dataset, args)  # filter collections
    dataset = filterSensorsFromDataset(dataset, args)  # filter sensors
    dataset = filterPatternsFromDataset(dataset, args)
    dataset = filterAdditionalTfsFromDataset(dataset, args)

    print("Loaded dataset containing " + str(len(dataset["sensors"].keys())) +
          " sensors and " + str(len(dataset["collections"].keys())) + " collections.")

    # ---------------------------------------
    # --- Verifications
    # ---------------------------------------
    checkIfAtLeastOneLabeledCollectionPerSensor(dataset)

    # ---------------------------------------
    # --- Define selected collection key.
    # ---------------------------------------
    # For the getters we only need to get one collection because optimized transformations are static, which means they are the same for all collections. Let's select the first key in the dictionary and always get that transformation.
    selected_collection_key = list(dataset["collections"].keys())[0]
    print("Selected collection key is " + str(selected_collection_key))

    # ---------------------------------------
    # --- Store initial values for transformations to be optimized
    # ---------------------------------------
    for collection_key, collection in dataset["collections"].items():
        initial_transform_key = generateName("transforms", suffix="ini")
        collection[initial_transform_key] = copy.deepcopy(
            collection["transforms"])
        for transform_key, transform in collection[initial_transform_key].items():
            transform["parent"] = generateName(
                transform["parent"], suffix="ini")
            transform["child"] = generateName(transform["child"], suffix="ini")

    # ---------------------------------------
    # --- Add noise to the transformations and joint parameters to be calibrated.
    # ---------------------------------------

    dataset_ground_truth = deepcopy(dataset)  # make a copy before adding noise

    addNoiseToInitialGuess(dataset, args, selected_collection_key)

    addNoiseFromNoisyTFLinks(dataset,args,selected_collection_key)

    # Must add noise, transfer new joints to transforms, and only afterwards remove unselected joints
    addBiasToJointParameters(dataset, args)
    replaceTransformsFromJoints(dataset)
    dataset = filterJointsFromDataset(dataset, args)
    dataset = filterJointParametersFromDataset(dataset, args)
    dataset_ground_truth = filterJointsFromDataset(dataset_ground_truth, args)
    dataset_ground_truth = filterJointParametersFromDataset(dataset_ground_truth, args)

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Create data models
    # ---------------------------------------
    opt = Optimizer()
    opt.addDataModel("args", args)
    opt.addDataModel("dataset", dataset)

    # ---------------------------------------
    # --- DEFINE THE VISUALIZATION FUNCTION
    # ---------------------------------------
    if args["view_optimization"]:
        opt.setInternalVisualization(True)
    else:
        opt.setInternalVisualization(False)

    if args["ros_visualization"]:
        print("Configuring visualization ... ")
        graphics = setupVisualization(dataset, args, selected_collection_key)
        opt.addDataModel("graphics", graphics)

        opt.setVisualizationFunction(
            visualizationFunction, args["ros_visualization"],
            niterations=1, figures=[])

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Process anchored sensors
    # ---------------------------------------
    # Each sensor will have a position (tx,ty,tz) and a rotation (r1,r2,r3)

    # Add parameters related to the sensors
    # remove anchored sensor if flagged as such.

    # Create a list of anchored transform keys, which is built from the anchored sensor and the argument --anchor_sensors.
    anchored_transform_keys = []

    if args['ignore_anchored_sensor']:
        dataset["calibration_config"]["anchored_sensor"] = None

    if dataset["calibration_config"]["anchored_sensor"] is not None:

        print("Anchored sensor is " + Fore.GREEN + dataset["calibration_config"]["anchored_sensor"]
              + Style.RESET_ALL)

        # Verify that the anchored sensor is in the dataset
        verifyAnchoredSensor(dataset["calibration_config"]["anchored_sensor"], dataset['sensors'])

        if dataset["calibration_config"]["anchored_sensor"] in dataset["sensors"]:
            anchored_parent = dataset["sensors"][
                dataset["calibration_config"]["anchored_sensor"]]["calibration_parent"]
            anchored_child = dataset["sensors"][
                dataset["calibration_config"]["anchored_sensor"]]["calibration_child"]
            anchored_transform_keys.append(generateKey(anchored_parent, anchored_child))
        else:
            atomError('Anchored sensor ' + dataset["calibration_config"]
                      ["anchored_sensor"] + ' does not exist in dataset.')

    # TODO If we want an anchored sensor we should search (and fix) all the transforms in its chain that are
    #  being optimized

    if args['anchor_sensors'] is not None:
        for sensor in args['anchor_sensors']:
            if sensor in dataset["sensors"]:
                parent = dataset["sensors"][sensor]["calibration_parent"]
                child = dataset["sensors"][sensor]["calibration_child"]
                anchored_transform_keys.append(generateKey(parent, child))
            else:
                atomError('Sensor ' + sensor + ' given in --anchor_sensors does not exist in dataset.')

    if anchored_transform_keys:
        print("These transform_keys are anchored: " + str(anchored_transform_keys))

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add sensor parameters
    # ---------------------------------------
    # Stemming from the config json, we define a transform to be optimized for each sensor. It could happen that two
    # or more sensors define the same transform to be optimized (#120). To cope with this we first create a list of
    # transformations to be optimized and then compute the unique set of that list.
    print("Creating sensor transformation parameters ...")
    sensors_transforms_set = set()
    for sensor_key, sensor in dataset["sensors"].items():
        transform_key = generateKey(
            sensor["calibration_parent"], sensor["calibration_child"])
        sensors_transforms_set.add(transform_key)

    # push six parameters for each transform to be optimized.
    for transform_key in sensors_transforms_set:
        ground_truth_transform = getterTransform(dataset_ground_truth, transform_key=transform_key,
                                                 collection_name=selected_collection_key)

        initial_transform = getterTransform(dataset, transform_key=transform_key,
                                            collection_name=selected_collection_key)

        if transform_key in anchored_transform_keys:
            bound_max = [x + 2 * sys.float_info.epsilon for x in initial_transform]
            bound_min = [x - 2 * sys.float_info.epsilon for x in initial_transform]
        else:
            bound_max = [+np.inf for x in initial_transform]
            bound_min = [-np.inf for x in initial_transform]

        opt.pushParamVector(
            group_name=transform_key, data_key="dataset", bound_max=bound_max, bound_min=bound_min,
            getter=partial(getterTransform, transform_key=transform_key,
                           collection_name=selected_collection_key,),
            setter=partial(setterTransform,
                           transform_key=transform_key, collection_name=None),
            suffix=["_x", "_y", "_z", "_r1", "_r2", "_r3"],
            ground_truth_values=ground_truth_transform)

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add additional_tfs parameters
    # ---------------------------------------
    print("Creating additional_tfs parameters ...")
    additional_transforms_set = set()

    if dataset['calibration_config']['additional_tfs'] is not None:
        for _, additional_tf in dataset['calibration_config']['additional_tfs'].items():
            transform_key = generateKey(additional_tf['parent_link'], additional_tf['child_link'])
            additional_transforms_set.add(transform_key)

    if dataset['calibration_config']['additional_tfs'] is not None:
        for _, additional_tf in dataset['calibration_config']['additional_tfs'].items():
            transform_key = generateKey(additional_tf['parent_link'], additional_tf['child_link'])

            # because of #900, and for retrocompatibility with old datasets, we will assume that if the transforms
            # field does not exist in the dataset, then the transformation is fixed
            if 'transforms' not in dataset or dataset['transforms'][transform_key]['type'] == 'fixed':

                # push six parameters for each transform to be optimized.
                ground_truth_transform = getterTransform(
                    dataset_ground_truth, transform_key=transform_key,
                    collection_name=selected_collection_key)

                initial_transform = getterTransform(dataset, transform_key=transform_key,
                                                    collection_name=selected_collection_key)

                if transform_key in anchored_transform_keys:
                    bound_max = [x + 2 * sys.float_info.epsilon for x in initial_transform]
                    bound_min = [x - 2 * sys.float_info.epsilon for x in initial_transform]
                else:
                    bound_max = [+np.inf for x in initial_transform]
                    bound_min = [-np.inf for x in initial_transform]

                opt.pushParamVector(
                    group_name=transform_key, data_key="dataset", bound_max=bound_max,
                    bound_min=bound_min,
                    getter=partial(
                        getterTransform, transform_key=transform_key,
                        collection_name=selected_collection_key,),
                    setter=partial(
                        setterTransform, transform_key=transform_key, collection_name=None),
                    suffix=["_x", "_y", "_z", "_r1", "_r2", "_r3"],
                    ground_truth_values=ground_truth_transform)

            elif dataset['transforms'][transform_key]['type'] == 'multiple':

                # iterate all collections
                for collection_key, collection in dataset["collections"].items():

                    # push six parameters for each transform to be optimized.
                    ground_truth_transform = getterTransform(
                        dataset_ground_truth, transform_key=transform_key,
                        collection_name=collection_key)

                    initial_transform = getterTransform(dataset, transform_key=transform_key,
                                                        collection_name=collection_key)

                    if transform_key in anchored_transform_keys:
                        bound_max = [x + 2 * sys.float_info.epsilon for x in initial_transform]
                        bound_min = [x - 2 * sys.float_info.epsilon for x in initial_transform]
                    else:
                        bound_max = [+np.inf for x in initial_transform]
                        bound_min = [-np.inf for x in initial_transform]

                    opt.pushParamVector(
                        group_name="c" + collection_key + "_" + transform_key, data_key="dataset",
                        bound_max=bound_max, bound_min=bound_min,
                        getter=partial(
                            getterTransform, transform_key=transform_key,
                            collection_name=collection_key),
                        setter=partial(
                            setterTransform, transform_key=transform_key,
                            collection_name=collection_key),
                        suffix=["_x", "_y", "_z", "_r1", "_r2", "_r3"],
                        ground_truth_values=ground_truth_transform)
            else:
                atomError('Unknown transform type ' + dataset['transforms'][transform_key]['type'])

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add intrinsic rgb sensor parameters
    # ---------------------------------------
    # TODO bound_min and max for intrinsics

    if args["optimize_intrinsics"]:
        for sensor_key, sensor in dataset["sensors"].items():
            if sensor["modality"] == "rgb":  # if sensor is a camera add intrinsics

                ground_truth_values = getterCameraIntrinsics(dataset_ground_truth, sensor_key)
                opt.pushParamVector(
                    group_name=str(sensor_key) + "_intrinsics", data_key="dataset",
                    getter=partial(getterCameraIntrinsics, sensor_key=sensor_key),
                    setter=partial(setterCameraIntrinsics, sensor_key=sensor_key),
                    suffix=["_fx", "_fy", "_cx", "_cy", "_k1", "_k2", "_t1", "_t2", "_k3"],
                    ground_truth_values=ground_truth_values)

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add pattern(s) parameters
    # ---------------------------------------

    # Each Pattern will have the position (tx,ty,tz) and rotation (r1,r2,r3)
    for pattern_key, pattern in dataset['calibration_config']['calibration_patterns'].items():

        # Pattern not fixed -----------------
        if not pattern["fixed"]:
            # If the pattern is not fixed there will be a transform for each collection. To tackle this, the reference
            # link, named according to dataset['calibration_config']['calibration_pattern']['link'], is prepended with
            # a "c<collection_name>" prefix. This is done automatically for the collection['transforms'] when
            # publishing to ROS, but we must also add this prefix to the parameter name.
            parent = pattern["parent_link"]
            child = pattern["link"]
            transform_key = generateKey(parent, child)

            # iterate all collections
            for collection_key, collection in dataset["collections"].items():

                # Set transform using the initial estimate of the transformations.

                initial_estimate = dataset["patterns"][pattern_key]["transforms_initial"][
                    collection_key]
                if not initial_estimate["detected"] or not parent == initial_estimate["parent"] or \
                        not child == initial_estimate["child"]:
                    atomError("Cannot set initial estimate for pattern " + Fore.BLUE + pattern_key +
                              Style.RESET_ALL + "at collection " + collection_key +
                              '. You must run the calibration without this collection. Check how to use the csf flag.')

                if transform_key not in dataset_ground_truth['collections'][
                        selected_collection_key]['transforms'].keys():
                    ground_truth_transform = None
                else:
                    ground_truth_transform = getterTransform(
                        dataset_ground_truth, transform_key=transform_key,
                        collection_name=collection_key)

                if not args['ground_truth_pattern_poses']:
                    collection["transforms"][transform_key] = {"parent": parent, "child": child,
                                                               "trans": initial_estimate["trans"],
                                                               "quat": initial_estimate["quat"]}

                # Finally push the six parameters that describe the pattern's pose w.r.t. its parent link:
                #   a) The getter will pick up the transform from the collection collection_key;
                #   b) The setter will receive a transform value and a collection_key and copy the transform to that of the corresponding collection.
                if not args['anchor_patterns']:

                    opt.pushParamVector(
                        group_name="c" + collection_key + "_" + transform_key, data_key="dataset",
                        suffix=["_x", "_y", "_z", "_r1", "_r2", "_r3"],
                        getter=partial(
                            getterTransform, transform_key=transform_key,
                            collection_name=collection_key),
                        setter=partial(
                            setterTransform, transform_key=transform_key,
                            collection_name=collection_key),
                        ground_truth_values=ground_truth_transform)

        else:  # fixed pattern ----------------------------------------------------------------------------------
            # if pattern is fixed it will not be replicated for all collections , i.e. there will be a single
            # reference link called according to what is on the dataset['calibration_config']['calibration_pattern'][
            # 'link']
            parent = pattern["parent_link"]
            child = pattern["link"]
            transform_key = generateKey(parent, child)

            # Set transform using the initial estimate of the transformations.
            # Because the pattern is fixed we can select any collection which provides an initial transform and use that one (see https://github.com/lardemua/atom/issues/866 )
            flg_have_initial_estimate = False
            for collection_key, collection in dataset["collections"].items():
                if dataset["patterns"][pattern_key]["transforms_initial"][collection_key]['detected']:
                    initial_estimate = dataset["patterns"][pattern_key]["transforms_initial"][
                        collection_key]
                    flg_have_initial_estimate = True
                    break  # select the first available initial_estimate

            if flg_have_initial_estimate == False:
                atomError(
                    "Cannot set initial estimate for pattern " + Fore.BLUE + pattern_key + Style.RESET_ALL)

            # TODO not sure why this check is here ...
            if not parent == initial_estimate["parent"] or not child == initial_estimate["child"]:
                atomError(
                    "Initial estimate for pattern " + Fore.BLUE + pattern_key + Style.RESET_ALL +
                    " does not have expected parent and child frame ids.")

            if transform_key not in dataset_ground_truth['collections'][selected_collection_key][
                    'transforms'].keys():
                ground_truth_transform = None
            else:
                ground_truth_transform = getterTransform(
                    dataset_ground_truth, transform_key=transform_key,
                    collection_name=selected_collection_key)

            # The pattern is fixed but we have a replicated transform for each collection. Lets add those.
            if not args['ground_truth_pattern_poses']:  # nothing to do if we should use ground truth
                for collection_key, collection in dataset["collections"].items():
                    collection["transforms"][transform_key] = {"parent": parent, "child": child,
                                                               "trans": initial_estimate["trans"],
                                                               "quat": initial_estimate["quat"]}

            # Finally push the six parameters that describe the pattern's pose w.r.t. its parent link:
            #   a) The Getter will pick up the transform from one selected collection (it does not really matter which,
            #       since they are replicas);
            #   b) The Setter will receive a transform value and copy that to all collection replicas, to ensure they
            #       all have the same value. This is done by setting "collection_name=None".
            if not args['anchor_patterns']:
                opt.pushParamVector(group_name=transform_key, data_key="dataset",
                                    getter=partial(getterTransform, transform_key=transform_key,
                                                   collection_name=selected_collection_key),
                                    setter=partial(setterTransform, transform_key=transform_key,
                                                   collection_name=None),
                                    suffix=["_x", "_y", "_z", "_r1", "_r2", "_r3"],
                                    ground_truth_values=ground_truth_transform)

    # ---------------------------------------
    # --- SETUP OPTIMIZER: Add joint(s) parameters
    # ---------------------------------------
    if not dataset['calibration_config']['joints'] == "":

        for joint_key, joint in dataset['collections'][selected_collection_key]['joints'].items():
            joint_config = dataset['calibration_config']['joints'][joint_key]

            for param_key in joint_config['params_to_calibrate']:
                ground_truth_value = getterJointParam(
                    dataset_ground_truth, joint_key, param_key, selected_collection_key)

                group_name = 'joint-' + joint_key + '-' + param_key
                getter = partial(getterJointParam, joint_key=joint_key,
                                 param_key=param_key, collection_name=selected_collection_key)
                setter = partial(setterJointParam, joint_key=joint_key,
                                 param_key=param_key, collection_name=None)
                opt.pushParamScalar(group_name=group_name, data_key='dataset', getter=getter,
                                    setter=setter, ground_truth_value=ground_truth_value)

        # TODO add joint limits from xacro
    # opt.printParameters()

    # ---------------------------------------
    # --- Define THE OBJECTIVE FUNCTION
    # ---------------------------------------
    opt.setObjectiveFunction(objectiveFunction)

    # ---------------------------------------
    # --- Define THE RESIDUALS
    # ---------------------------------------
    # Each residual is computed after the sensor and the pattern of a collection. Thus, each error will be affected
    # by the parameters tx,ty,tz,r1,r2,r3 of the sensor and the pattern

    print("Creating residuals ... ")

    if args["sample_seed"] is None:
        seed = random.randrange(sys.maxsize)
    else:
        seed = args["sample_seed"]

    rng = random.Random(seed)
    # print("RNG Seed: " + str(seed))

    for collection_key, collection in dataset["collections"].items():
        for pattern_key, pattern in dataset['calibration_config']['calibration_patterns'].items():
            for sensor_key, sensor in dataset["sensors"].items():
                # if pattern not detected by sensor in collection
                if not collection["labels"][pattern_key][sensor_key]["detected"]:
                    continue

                # Sensor related parameters
                # sensors_transform_key = generateKey(
                #     sensor["calibration_parent"], sensor["calibration_child"])
                # params = opt.getParamsContainingPattern(sensors_transform_key)

                # Issue #543: Create the list of transformations that influence this residual by analyzing the transformation chain.
                params = []

                # Sensor based residuals
                transforms_list = list(sensors_transforms_set)
                for transform_in_chain in sensor['chain']:
                    transform_key = generateKey(
                        transform_in_chain["parent"], transform_in_chain["child"])
                    if transform_key in transforms_list:
                        params.extend(opt.getParamsContainingPattern(transform_key))

                # Additional_tf based residuals
                if dataset['calibration_config']['additional_tfs'] is not None:
                    for _, additional_tf in dataset['calibration_config']['additional_tfs'].items():
                        transform_key = generateKey(
                            additional_tf['parent_link'],
                            additional_tf['child_link'])
                        if dataset['transforms'][transform_key]['type'] == 'fixed':
                            params.extend(opt.getParamsContainingPattern(transform_key))
                        elif dataset['transforms'][transform_key]['type'] == 'multiple':
                            params.extend(opt.getParamsContainingPattern(
                                "c" + collection_key + "_" + transform_key))

                # Intrinsics parameters
                if sensor["modality"] == "rgb" and args["optimize_intrinsics"]:
                    params.extend(opt.getParamsContainingPattern(
                        sensor_key + "_intrinsics"))

                # Pattern related parameters
                if pattern["fixed"]:
                    pattern_transform_key = generateKey(pattern["parent_link"],
                                                        pattern["link"])
                else:
                    pattern_transform_key = "c" + collection_key + "_" + \
                        generateKey(pattern["parent_link"], pattern["link"])

                params.extend(opt.getParamsContainingPattern(
                    pattern_transform_key))  # pattern related params

                # TODO Append joint params. Right now appending all joint params. Should be clever and know, from residual, which will affect
                if not dataset['calibration_config']['joints'] == "":
                    for joint_key, joint in dataset['collections'][selected_collection_key][
                            'joints'].items():
                        pattern_joint = 'joint-' + joint_key
                        params.extend(opt.getParamsContainingPattern(
                            pattern_joint))  # pattern related params

                if sensor["modality"] == "rgb":
                    # Compute step as a function of residual sampling factor
                    # using all pattern corners
                    for idx in collection["labels"][pattern_key][sensor_key]["idxs"]:
                        rname = "c" + str(collection_key) + "_" + "p_" + pattern_key + "_" +\
                            str(sensor_key) + "_corner" + str(idx["id"])
                        opt.pushResidual(name=rname, params=params)

                elif sensor["modality"] == "lidar3d":

                    # Laser beam error (or orthogonal error) ==

                    # The number of residuals for this error can be huge.
                    # Therefore, we do a random sample to reduce the number.
                    # The sample percentage has the interval [0.1, 1.0]
                    population = range(
                        0, len(collection["labels"][pattern_key][sensor_key]["idxs"]))
                    number_of_samples = int(
                        max(0.1, min(args["sample_residuals"], 1.0)) * len(population))
                    samples = rng.sample(population, number_of_samples)

                    # Save it to be used in the objective function.
                    collection["labels"][pattern_key][sensor_key]["samples"] = samples

                    for idx in samples:
                        rname = "c" + collection_key + "_p_" + \
                            pattern_key + '_' + sensor_key + "_oe_" + str(idx)
                        opt.pushResidual(name=rname, params=params)

                    # Extrema (limits) displacement error
                    for idx in range(
                        0, len(
                            collection["labels"][pattern_key][sensor_key]["idxs_limit_points"])):
                        rname = "c" + collection_key + "_" + "p_" + \
                            pattern_key + "_" + sensor_key + "_ld_" + str(idx)
                        opt.pushResidual(name=rname, params=params)

                elif sensor["modality"] == "depth":

                    population = range(
                        0, len(collection["labels"][pattern_key][sensor_key]["idxs"]))
                    number_of_samples = int(
                        max(0.1, min(args["sample_residuals"], 1.0)) * len(population))
                    samples = rng.sample(population, number_of_samples)

                    # Save it to be used in the objective function.
                    collection["labels"][pattern_key][sensor_key]["samples"] = samples

                    for idx in samples:
                        opt.pushResidual(
                            name="c" + collection_key + "_" + "p_" + pattern_key + "_" + sensor_key +
                            "_oe_" + str(idx),
                            params=params)

                    population_longitudinal = range(
                        0, len(collection["labels"][pattern_key][sensor_key]["idxs_limit_points"]))
                    number_of_samples_longitudinal = int(
                        max(0.1, min(args["sample_longitudinal_residuals"], 1.0)) * len(population_longitudinal))
                    samples_longitudinal = rng.sample(
                        population_longitudinal, number_of_samples_longitudinal)

                    # Save it to be used in the objective function.
                    collection["labels"][pattern_key][sensor_key]["samples_longitudinal"] = samples_longitudinal

                    # Extrema displacement error
                    for idx in samples_longitudinal:
                        opt.pushResidual(
                            name="c" + collection_key + "_" + "p_" + pattern_key + "_" + sensor_key +
                            "_ld_" + str(idx),
                            params=params)

                # print(opt.residuals.keys())
                # print('Adding residuals for sensor ' + sensor_key + ' with msg_type ' + sensor['msg_type'] +
                #       ' affected by parameters:\n' + str(params))

    # opt.printResiduals()

    # ---------------------------------------
    # --- Compute the SPARSE MATRIX
    # ---------------------------------------
    print("Computing sparse matrix ... ")
    opt.computeSparseMatrix()
    # opt.printSparseMatrix()
    # opt.printParameters()

    # ---------------------------------------
    # --- Get a normalizer for each residual type
    # ---------------------------------------
    modalities = set([s["modality"] for s in dataset["sensors"].values()])
    normalizer = {k: 1.0 for k in modalities}
    opt.addDataModel("normalizer", normalizer)

    residuals = objectiveFunction(opt.data_models)
    opt.callObjectiveFunction()
    for modality in modalities:
        values = []
        for sensor_key, sensor in dataset["sensors"].items():
            if sensor["modality"] == modality:
                values += [residuals[k]
                           for k in residuals.keys() if sensor_key in k]

        # If the normalizer is too big, the residuals may lose their influence when close to zero.
        if modality == 'rgb':
            normalizer[modality] = args['rgb_normalizer_multiplier'] * np.mean(values)
        else:
            normalizer[modality] = np.mean(values)

        print('Normalizer for ' + str(modality) + ': ' + str(normalizer[modality]))

    # ---------------------------------------
    # --- Start Optimization
    # ---------------------------------------
    print('Optimization setup with ' + Fore.BLUE + str(len(opt.getParamNames())) +
          Style.RESET_ALL + ' parameters: ' + str(opt.getParamNames()))

    # store initial values of dataset before starting optimization
    dataset_initial = deepcopy(dataset)

    if args["phased_execution"]:
        waitForKeyPress2(function=opt.callObjectiveFunction, timeout=0.1)

    print("Initializing optimization ...")
    options = {
        'ftol': args['optimization_ftol'],
        'xtol': args['optimization_xtol'],
        'gtol': args['optimization_gtol'],
        'diff_step': args['optimization_diff_step'],
        'max_nfev': args['optimization_max_nfev'],
        'x_scale': 'jac'}
    opt.startOptimization(optimization_options=options)

    if args["phased_execution"]:
        waitForKeyPress2(opt.callObjectiveFunction, timeout=0.1,
                         message=Fore.BLUE + "Optimization finished" + Style.RESET_ALL)
    # Final error report
    errorReport(dataset=dataset, residuals=objectiveFunction(
        opt.data_models), normalizer=normalizer, args=args)

    # ---------------------------------------
    # --- Print Results
    # ---------------------------------------
    if args['comparison_to_ground_truth']:
        # opt.printParametersErrors() # only if we want to see the value of each parameter
        printComparisonToGroundTruth(dataset, dataset_initial, dataset_ground_truth,
                                     selected_collection_key, output_folder, args)

    # ---------------------------------------
    # --- Save command line args
    # ---------------------------------------
    saveCommandLineArgsYml(args_original, output_folder + '/command_line_args.yml')

    # ---------------------------------------
    # --- Save updated JSON file
    # ---------------------------------------
    # to store all sensors that existed in the dataset, regardless of which were optimized
    # If a sensor_key exists in the original dataset but not in the optimized dataset, then it was removed from the optimization and we should copy it back to the dataset before saving.
    for sensor_key, sensor in dataset_original["sensors"].items():
        if not sensor_key in dataset['sensors']:
            dataset['sensors'][sensor_key] = sensor

    saveAtomDataset(output_folder + '/atom_calibration.json', dataset,
                    save_data_files=True)

    # ---------------------------------------
    # --- Save updated xacro
    # ---------------------------------------
    # TODO #897 now that the transforms can also come from the bagfile, this must be updated to handle that case.
    sensors_transforms_list = list(sensors_transforms_set)
    additional_transforms_list = list(additional_transforms_set)
    transforms_list = sensors_transforms_list + additional_transforms_list
    saveResultsXacro(dataset, selected_collection_key, transforms_list, verbose=False)

    # ---------------------------------------
    # --- Save results into yaml
    # ---------------------------------------
    filename_results_yml = output_folder + '/atom_calibration_params.yml'
    saveResultsYml(dataset, selected_collection_key, filename_results_yml)
    filename_results_yml = output_folder + '/original_calibration_params.yml'
    saveResultsYml(dataset_original, selected_collection_key, filename_results_yml)


# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not trigger it.
if __name__ == "__main__":
    main()
