import numpy as np
from numpy.linalg import norm


def projectToCamera(intrinsic_matrix, distortion, width, height, pts):
    """
    Projects points expressed in the camera frame to image pixels using a pinhole model
    with OpenCV-style radial / tangential distortion.

    Model (OpenCV calib3d, "Camera Calibration and 3D Reconstruction"):
        x' = x / z,  y' = y / z,  r^2 = x'^2 + y'^2
        radial = (1 + k1*r^2 + k2*r^4 + k3*r^6) / (1 + k4*r^2 + k5*r^4 + k6*r^6)
        x'' = x' * radial + 2*p1*x'*y' + p2*(r^2 + 2*x'^2)
        y'' = y' * radial + p1*(r^2 + 2*y'^2) + 2*p2*x'*y'
        u = fx*x'' + cx,  v = fy*y'' + cy

    :param intrinsic_matrix: 3x3 intrinsic camera matrix (numpy array or nested sequence)
    :param distortion: None for no distortion, or the coefficients
        (k_1, k_2, p_1, p_2[, k_3[, k_4, k_5, k_6]]) in any array shape (e.g. OpenCV's 1xN).
        Missing trailing coefficients are taken as zero; coefficients beyond the 8th are ignored.
    :param width: the image width, in pixels
    :param height: the image height, in pixels
    :param pts: np array 3xn or 4xn of point coordinates in the camera frame. Only the first three
        rows are read (a 4th row is ignored -- presumably homogeneous with w == 1).
    :return: tuple (pixs, valid_pixs, dists)
        pixs: 2xn float array of pixel coordinates (row 0 = u, row 1 = v)
        valid_pixs: length-n bool array, True where z > 0 and 0 <= u < width and 0 <= v < height
        dists: length-n array with the Euclidean distance from each point to the camera origin
    :raises ValueError: if ``distortion`` has fewer than 4 coefficients
    """
    _, n_pts = pts.shape

    # np.asarray makes [row, col] indexing work for ndarrays, nested lists and np.matrix alike
    K = np.asarray(intrinsic_matrix)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Normalise the distortion vector to 8 coefficients. This accepts 4/5/8-element sequences as
    # well as OpenCV-style 1xN / Nx1 arrays. With k4 = k5 = k6 = 0 the rational model below
    # reduces exactly to the classic 5-coefficient polynomial model.
    coeffs = np.zeros(8)
    if distortion is not None:
        given = np.ravel(np.asarray(distortion, dtype=float))
        if given.size < 4:
            raise ValueError('distortion must have at least 4 coefficients (k1, k2, p1, p2), '
                             'got {}'.format(given.size))
        n_used = min(given.size, 8)
        coeffs[:n_used] = given[:n_used]
    k1, k2, p1, p2, k3, k4, k5, k6 = coeffs

    x = pts[0, :]
    y = pts[1, :]
    z = pts[2, :]

    dists = norm(pts[0:3, :], axis=0)  # distance from each point to the camera origin

    pixs = np.zeros((2, n_pts), dtype=float)
    # Points with z == 0 produce inf/nan below. They are rejected by the z > 0 mask anyway, so
    # the corresponding floating point warnings are silenced rather than emitted on every call.
    with np.errstate(divide='ignore', invalid='ignore'):
        xl = x / z  # normalized image coordinates
        yl = y / z
        r2 = xl ** 2 + yl ** 2  # squared radius, reused by every distortion term
        radial = ((1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3) /
                  (1 + k4 * r2 + k5 * r2 ** 2 + k6 * r2 ** 3))
        xll = xl * radial + 2 * p1 * xl * yl + p2 * (r2 + 2 * xl ** 2)
        yll = yl * radial + p1 * (r2 + 2 * yl ** 2) + 2 * p2 * xl * yl
        pixs[0, :] = fx * xll + cx
        pixs[1, :] = fy * yll + cy

        # Valid projections lie in front of the camera and inside the image bounds
        # (NaN coordinates compare False and are therefore invalid).
        valid_z = z > 0
        valid_xpix = np.logical_and(pixs[0, :] >= 0, pixs[0, :] < width)
        valid_ypix = np.logical_and(pixs[1, :] >= 0, pixs[1, :] < height)
        valid_pixs = np.logical_and(valid_z, np.logical_and(valid_xpix, valid_ypix))

    return pixs, valid_pixs, dists


def projectToCamera_faster(intrinsic_matrix, distortion, width, height, pts):
    """
    Projects points expressed in the camera frame to image pixels using a pinhole model
    with OpenCV-style radial / tangential distortion, without computing point distances.

    Model (OpenCV calib3d, "Camera Calibration and 3D Reconstruction"):
        x' = x / z,  y' = y / z,  r^2 = x'^2 + y'^2
        radial = (1 + k1*r^2 + k2*r^4 + k3*r^6) / (1 + k4*r^2 + k5*r^4 + k6*r^6)
        x'' = x' * radial + 2*p1*x'*y' + p2*(r^2 + 2*x'^2)
        y'' = y' * radial + p1*(r^2 + 2*y'^2) + 2*p2*x'*y'
        u = fx*x'' + cx,  v = fy*y'' + cy

    :param intrinsic_matrix: 3x3 intrinsic camera matrix (numpy array or nested sequence)
    :param distortion: None for no distortion, or the coefficients
        (k_1, k_2, p_1, p_2[, k_3[, k_4, k_5, k_6]]) in any array shape (e.g. OpenCV's 1xN).
        Missing trailing coefficients are taken as zero; coefficients beyond the 8th are ignored.
    :param width: the image width, in pixels
    :param height: the image height, in pixels
    :param pts: np array 3xn or 4xn of point coordinates in the camera frame. Only the first three
        rows are read (a 4th row is ignored -- presumably homogeneous with w == 1).
    :return: tuple (pixs, valid_pixs)
        pixs: 2x1xn array of pixel coordinates (pixs[0, 0] = u, pixs[1, 0] = v)
        valid_pixs: 1xn bool array, True where z > 0 and 0 <= u < width and 0 <= v < height
    :raises ValueError: if ``distortion`` has fewer than 4 coefficients
    """
    # np.asarray makes [row, col] indexing work for ndarrays and nested lists alike
    K = np.asarray(intrinsic_matrix)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Normalise the distortion vector to 8 coefficients. Previously a 4-element vector or a 1xN
    # array failed to unpack, and k4..k6 were silently dropped. With k4 = k5 = k6 = 0 the
    # rational model reduces exactly to the 5-coefficient polynomial model.
    coeffs = np.zeros(8)
    if distortion is not None:
        given = np.ravel(np.asarray(distortion, dtype=float))
        if given.size < 4:
            raise ValueError('distortion must have at least 4 coefficients (k1, k2, p1, p2), '
                             'got {}'.format(given.size))
        n_used = min(given.size, 8)
        coeffs[:n_used] = given[:n_used]
    k1, k2, p1, p2, k3, k4, k5, k6 = coeffs

    x, y, z = pts[0], pts[1], pts[2]

    # Points with z == 0 produce inf/nan; they are rejected by the z > 0 mask, so the
    # corresponding floating point warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Normalized coordinates and squared radius
        xl = x / z
        yl = y / z
        r2 = xl ** 2 + yl ** 2

        # Apply distortion model
        radial = ((1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3) /
                  (1 + k4 * r2 + k5 * r2 ** 2 + k6 * r2 ** 3))
        xll = xl * radial + 2 * p1 * xl * yl + p2 * (r2 + 2 * xl ** 2)
        yll = yl * radial + p1 * (r2 + 2 * yl ** 2) + 2 * p2 * xl * yl

        # NOTE(review): the extra brackets make this array 2x1xn (projectToCamera returns 2xn).
        # Kept as-is because callers may index it this way -- confirm before flattening.
        pixs = np.array([[fx * xll + cx], [fy * yll + cy]])

        # Valid projections lie in front of the camera and inside the image bounds.
        # pixs[0, :] is 1xn, so the mask broadcasts to 1xn.
        valid_z = z > 0
        valid_xpix = (pixs[0, :] >= 0) & (pixs[0, :] < width)
        valid_ypix = (pixs[1, :] >= 0) & (pixs[1, :] < height)
        valid_pixs = valid_z & valid_xpix & valid_ypix

    return pixs, valid_pixs


# def batch_projectToCamera(intrinsic_matrix, distortion, width, height, pts):
#     """
#     Projects a batch of points to the camera defined transform, intrinsics, and distortion
#     :param intrinsic_matrix: 3x3 intrinsic camera matrix
#     :param distortion: should be as follows: (k_1, k_2, p_1, p_2[, k_3[, k_4, k_5, k_6]])
#     :param width: the image width
#     :param height: the image height
#     :param pts: a list of point coordinates (in the camera frame) with the following format: np array 4xn or 3xn
#     :return: a list of pixel coordinates with the same length as pts
#     """
#
#     n_pts = pts.shape[1]
#
#     # Extract parameters from intrinsic matrix
#     fx = intrinsic_matrix[0, 0]
#     fy = intrinsic_matrix[1, 1]
#     cx = intrinsic_matrix[0, 2]
#     cy = intrinsic_matrix[1, 2]
#
#     # Unpack distortion coefficients
#     k1, k2, p1, p2, k3 = distortion[:5] if distortion is not None else (0, 0, 0, 0, 0)
#
#     # Extract point coordinates
#     x, y, z = pts[0], pts[1], pts[2]
#
#     # Compute normalized coordinates and square of radius
#     xl = x / z
#     yl = y / z
#     r2 = xl**2 + yl**2
#
#     # Apply distortion model
#     xll = xl * (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) + 2 * p1 * xl * yl + p2 * (r2 + 2 * xl**2)
#     yll = yl * (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) + p1 * (r2 + 2 * yl**2) + 2 * p2 * xl * yl
#
#     # Project to pixel coordinates
#     pixs = np.array([[fx * xll + cx], [fy * yll + cy]])
#
#     return pixs


def batch_projectToCamera(intrinsic_matrix, distortion, width, height, pts):
    """
    Projects a batch of points expressed in the camera frame to image pixels using a pinhole
    model with OpenCV-style radial / tangential distortion. No validity mask is computed.

    Model (OpenCV calib3d, "Camera Calibration and 3D Reconstruction"):
        x' = x / z,  y' = y / z,  r^2 = x'^2 + y'^2
        radial = (1 + k1*r^2 + k2*r^4 + k3*r^6) / (1 + k4*r^2 + k5*r^4 + k6*r^6)
        x'' = x' * radial + 2*p1*x'*y' + p2*(r^2 + 2*x'^2)
        y'' = y' * radial + p1*(r^2 + 2*y'^2) + 2*p2*x'*y'
        u = fx*x'' + cx,  v = fy*y'' + cy

    :param intrinsic_matrix: 3x3 intrinsic camera matrix (numpy array or nested sequence)
    :param distortion: None for no distortion, or the coefficients
        (k_1, k_2, p_1, p_2[, k_3[, k_4, k_5, k_6]]) in any array shape (e.g. OpenCV's 1xN).
        Missing trailing coefficients are taken as zero; coefficients beyond the 8th are ignored.
    :param width: the image width (not used here; kept for signature parity with projectToCamera)
    :param height: the image height (not used here; kept for signature parity with projectToCamera)
    :param pts: np array 3xn or 4xn of point coordinates in the camera frame. Only the first three
        rows are read (a 4th row is ignored -- presumably homogeneous with w == 1).
    :return: 2x1xn array of pixel coordinates (pixs[0, 0] = u, pixs[1, 0] = v). Points with
        z == 0 yield inf/nan entries; points with z < 0 are projected without being flagged.
    :raises ValueError: if ``distortion`` has fewer than 4 coefficients
    """
    # np.asarray makes [row, col] indexing work for ndarrays and nested lists alike
    K = np.asarray(intrinsic_matrix)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Normalise the distortion vector to 8 coefficients. Previously a 4-element vector or a 1xN
    # array failed to unpack, and k4..k6 were silently dropped. With k4 = k5 = k6 = 0 the
    # rational model reduces exactly to the 5-coefficient polynomial model.
    coeffs = np.zeros(8)
    if distortion is not None:
        given = np.ravel(np.asarray(distortion, dtype=float))
        if given.size < 4:
            raise ValueError('distortion must have at least 4 coefficients (k1, k2, p1, p2), '
                             'got {}'.format(given.size))
        n_used = min(given.size, 8)
        coeffs[:n_used] = given[:n_used]
    k1, k2, p1, p2, k3, k4, k5, k6 = coeffs

    x, y, z = pts[0], pts[1], pts[2]

    # z == 0 produces inf/nan by design (there is no mask here); silence the FP warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Normalized coordinates and squared radius
        xl = x / z
        yl = y / z
        r2 = xl ** 2 + yl ** 2

        # Apply distortion model
        radial = ((1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3) /
                  (1 + k4 * r2 + k5 * r2 ** 2 + k6 * r2 ** 3))
        xll = xl * radial + 2 * p1 * xl * yl + p2 * (r2 + 2 * xl ** 2)
        yll = yl * radial + p1 * (r2 + 2 * yl ** 2) + 2 * p2 * xl * yl

        # NOTE(review): the extra brackets make this array 2x1xn (projectToCamera returns 2xn).
        # Kept as-is because callers may index it this way -- confirm before flattening.
        pixs = np.array([[fx * xll + cx], [fy * yll + cy]])

    return pixs