#!/usr/bin/env python3

import os
import sys
import argparse

import cv2
import numpy as np
import atom_core.config_io
import rospy
import tf
from sensor_msgs.msg import Image, CameraInfo
from cv_bridge import CvBridge
from atom_calibration.collect import patterns
from atom_core.geometry import traslationRodriguesToTransform
from tf import transformations


class SimplePatternDetector:
    """Detect a calibration pattern in an image stream and publish the result.

    The detector subscribes to the image topic of one sensor from the
    configuration, runs the configured pattern detector on every incoming
    image, draws the detection on the image, publishes the drawn image on
    ``<topic>/labeled`` and shows it in an OpenCV window.
    """

    def __init__(self, args, config):
        """Build the pattern detector and connect it to the sensor topics.

        Args:
            args: dict of command line arguments. The keys 'pattern_name',
                'sensor_name' and 'use_ir' are read.
            config: calibration configuration dict. The entries
                config['calibration_patterns'][<pattern_name>] and
                config['sensors'][<sensor_name>]['topic_name'] are read.

        Raises:
            KeyError: if a required configuration entry is missing.
            SystemExit: if the pattern type is neither 'charuco' nor
                'chessboard'.
        """
        self.args = args
        self.config = config

        # TODO only works for first pattern
        self.pattern_key = args['pattern_name']
        pattern_config = config['calibration_patterns'][self.pattern_key]
        dimension = pattern_config['dimension']
        print(dimension)
        print(dimension['x'])
        size = {'x': dimension['x'], 'y': dimension['y']}
        length = pattern_config['size']

        # Total number of corners, only used for the console report.
        self.number_of_corners = dimension['x'] * dimension['y']

        # TODO only works for first pattern
        pattern_type = pattern_config['pattern_type']
        if pattern_type == 'charuco':
            # 'inner_size' and 'dictionary' are only needed by charuco
            # patterns, so they are only required for that pattern type.
            self.pattern = patterns.CharucoPattern(
                size, length, pattern_config['inner_size'],
                pattern_config['dictionary'])
        elif pattern_type == 'chessboard':
            self.pattern = patterns.ChessboardPattern(size, length)
        else:
            rospy.logerr("Unknown pattern '{}'".format(pattern_type))
            sys.exit(1)

        # Get a camera_info message. The camera_info topic is taken to be a
        # sibling of the image topic, i.e. <dirname(topic)>/camera_info.
        self.topic = config['sensors'][args['sensor_name']]['topic_name']
        camera_info_topic = os.path.dirname(self.topic) + '/camera_info'

        print('Waiting for camera_info message on topic ' +
              camera_info_topic + ' ...')
        self.camera_info_msg = rospy.wait_for_message(
            camera_info_topic, CameraInfo)
        print('... received!')

        # The camera_info message is read only once, so the intrinsics are
        # constant and do not need to be rebuilt for every image.
        self.K, self.D = self._cameraIntrinsics(self.camera_info_msg)

        # Kept as a public attribute for backward compatibility; it is not
        # used by the code of this class.
        self.broadcaster = tf.TransformBroadcaster()

        self.bridge = CvBridge()
        self.image_pub = rospy.Publisher(
            self.topic + '/labeled', Image, queue_size=1)

        # Subscribe last: the callback uses the attributes set above, so it
        # must not be able to run before they all exist.
        self.sub = rospy.Subscriber(
            self.topic, Image, self.onImageReceived, queue_size=1)

    @staticmethod
    def _cameraIntrinsics(camera_info_msg):
        """Return the intrinsic matrix and distortion vector of a camera.

        Args:
            camera_info_msg: object with a 'K' attribute (9 values, row-major
                3x3 matrix) and a 'D' attribute (sequence of distortion
                coefficients, possibly empty).

        Returns:
            Tuple (K, D) of float arrays with shapes (3, 3) and (5, 1). Only
            the first five distortion coefficients are kept; missing ones are
            filled with zeros.
        """
        # Convert by value (not by reinterpreting a raw buffer) so that any
        # numeric input type gives the right numbers.
        K = np.array(camera_info_msg.K, dtype=float).reshape(3, 3)

        D = np.zeros((5, 1), dtype=float)
        coefficients = np.array(camera_info_msg.D, dtype=float).ravel()[:5]
        D[:coefficients.size, 0] = coefficients
        return K, D

    def onImageReceived(self, image_msg):
        """Detect the pattern in one image, then publish and display it.

        Args:
            image_msg: the sensor_msgs Image message received on the sensor
                topic.
        """
        use_ir = self.args['use_ir']
        if use_ir:
            image = self.bridge.imgmsg_to_cv2(image_msg, 'mono16')
            # NOTE(review): astype() keeps only the low 8 bits of each value,
            # it does not rescale the 16 bit range. Confirm this is intended.
            image = image.astype(np.uint8)
        else:
            image = self.bridge.imgmsg_to_cv2(image_msg, 'bgr8')

        result = self.pattern.detect(image, equalize_histogram=False)

        if result['detected']:
            print('Pattern detected (' +
                  str(len(result['ids'])) + ' out of ' +
                  str(self.number_of_corners) + ' corners)')
        else:
            print('Failed to detect pattern')

        self.pattern.drawKeypoints(image, result, self.K, self.D)

        encoding = 'mono8' if use_ir else 'bgr8'
        image_msg_out = self.bridge.cv2_to_imgmsg(image, encoding)
        self.image_pub.publish(image_msg_out)

        cv2.namedWindow(self.topic, cv2.WINDOW_NORMAL)
        cv2.imshow(self.topic, image)
        key = cv2.waitKey(10)
        # Pressing 'q' in the image window stops the node.
        if key & 0xff == ord('q'):
            rospy.signal_shutdown('User pressed q in the image window')

def main():
    """Parse the command line, load the configuration and run the detector.

    Initializes the ROS node, creates a SimplePatternDetector for the
    requested sensor and pattern, and then blocks in rospy.spin() until the
    node is shut down.
    """
    rospy.init_node('detect_chessboard', anonymous=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("-cfg", "--config_file",
                        help='Specify if you want to configure the calibration package with a specific configuration file. If this flag is not given, the standard config.yml will be used.',
                        type=str, required=False, default=None)
    parser.add_argument("-sn", "--sensor_name", help="Sensor to test as named in the config file.", type=str,
                        required=True)
    parser.add_argument("-pn", "--pattern_name", help="Pattern to test as named in the config file.", type=str,
                        required=True)
    parser.add_argument("-uir", "--use_ir",
                        help="Treat the sensor as an infrared camera: images are read as mono16 and converted to 8 bit.",
                        action="store_true")

    # rospy.myargv() removes the ROS remapping arguments (e.g. __name:=...)
    # that roslaunch appends, which argparse would otherwise reject.
    args = vars(parser.parse_args(args=rospy.myargv(argv=sys.argv)[1:]))

    # Read config file
    config = atom_core.config_io.loadConfig(args['config_file'])

    # Create detector. Its constructor registers the image subscriber and
    # the publisher of the labeled image.
    SimplePatternDetector(args, config)
    rospy.spin()


# Run the node only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
