# Import modules of interest
import cv2
import numpy as np
from primesense import openni2
from primesense import _openni2 as c_api
import time

# FPS bookkeeping for the capture loop further down:
#   start_time -- wall-clock start of the current measurement window
#   x          -- reporting interval in seconds (the frame rate is printed every 1 second)
#   counter    -- frames shown since start_time
start_time, x, counter = time.time(), 1, 0

# ======= OPENNI2 Paths =======

# print(openni2.enumerateDevices())
class OpenniDevice:
    """Wrapper around one OpenNI2 device and its depth/colour streams.

    Frames are shown in OpenCV windows titled ``<window_name>_depth`` and
    ``<window_name>_rgb``.
    """

    def __init__(self, uri, window_name, redist_path="/lib/OpenNI2/Redist/"):
        """Initialise OpenNI2, open the device at ``uri`` and create its depth stream.

        Args:
            uri: Device URI as bytes, e.g. ``b'2bc5/0403@1/12'``.
            window_name: Prefix for the OpenCV window titles.
            redist_path: Directory holding the OpenNI2 Redist libraries.
        """
        self.uri = uri
        # Initialize OpenNI with its libraries (the OpenNI2 Redist folder).
        openni2.initialize(redist_path)
        print("successfully initialized")
        self.dev = openni2.Device(self.uri)
        self.depth_streamer = self.dev.create_depth_stream()
        # Created on demand by create_rgb_stream(), and only when the device
        # reports a colour sensor.
        self.rgb_streamer = None
        self.window_name = window_name

    def create_depth_stream(self):
        """Configure and start the depth stream (100 um units, 640x480 @ 30 fps).

        Returns:
            True if the device has a depth sensor and the stream was started,
            False otherwise.
        """
        if not self.dev.has_sensor(openni2.SENSOR_DEPTH):
            return False
        # Configure before start(): some OpenNI drivers may refuse a mode change
        # on a stream that is already running.
        self.depth_streamer.set_video_mode(
            c_api.OniVideoMode(pixelFormat=c_api.OniPixelFormat.ONI_PIXEL_FORMAT_DEPTH_100_UM,
                               resolutionX=640,
                               resolutionY=480,
                               fps=30))
        self.depth_streamer.start()
        print("started depth stream")
        return True

    def create_rgb_stream(self):
        """Create, configure and start the colour stream (RGB888, 640x480 @ 30 fps).

        Returns:
            True if the device has a colour sensor and the stream was started,
            False otherwise.
        """
        print("trying to set rgb stream")
        has_color = self.dev.has_sensor(openni2.SENSOR_COLOR)
        print(has_color)
        if not has_color:
            return False
        if self.rgb_streamer is None:
            self.rgb_streamer = self.dev.create_color_stream()
        # Configure before start(), as for the depth stream.
        self.rgb_streamer.set_video_mode(
            c_api.OniVideoMode(pixelFormat=c_api.OniPixelFormat.ONI_PIXEL_FORMAT_RGB888,
                               resolutionX=640,
                               resolutionY=480,
                               fps=30))
        self.rgb_streamer.start()
        return True

    @staticmethod
    def _depth_to_display(depth_buffer, height, width, shift=8):
        """Convert a raw uint16 depth buffer into a 3-channel uint8 display image.

        A plain ``astype(np.uint8)`` keeps only the low byte, so depth wraps every
        256 units and renders as repeating bands. Discarding the low ``shift``
        bits instead maps depth monotonically onto 0..255. Values that still
        exceed 255 (only possible when ``shift < 8``) saturate rather than wrap.

        Args:
            depth_buffer: Object exposing ``height * width`` uint16 values through
                the buffer protocol (e.g. ``frame.get_buffer_as_uint16()``).
            height: Frame height in pixels.
            width: Frame width in pixels.
            shift: Number of low-order bits to discard.

        Returns:
            A ``(height, width, 3)`` uint8 array whose three channels are identical.
        """
        depth = np.frombuffer(depth_buffer, dtype=np.uint16).reshape(height, width)
        grey = np.minimum(depth >> shift, 255).astype(np.uint8)
        return np.dstack((grey, grey, grey))

    def depth_stream_image(self):
        """Read one depth frame and show it as a grey image in ``<window_name>_depth``."""
        frame = self.depth_streamer.read_frame()
        # The conversion copies the data, so nothing refers to the frame buffer afterwards.
        img = self._depth_to_display(frame.get_buffer_as_uint16(), frame.height, frame.width)
        cv2.imshow(self.window_name + "_depth", img)

    def rgb_stream_image(self):
        """Read one colour frame and show it in ``<window_name>_rgb``."""
        # Read exactly one frame per call; reading two would silently drop every other frame.
        frame = self.rgb_streamer.read_frame()
        rgb = np.frombuffer(frame.get_buffer_as_uint8(), dtype=np.uint8).reshape(
            frame.height, frame.width, 3)
        # The stream is configured as RGB888; OpenCV windows expect BGR channel order.
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        cv2.imshow(self.window_name + "_rgb", bgr)


# def main():
# Open both cameras. The URIs are specific to this machine; the commented-out
# openni2.enumerateDevices() call near the top of the file can be used to list them.
dev1 = OpenniDevice(uri=b'2bc5/0403@1/12', window_name="camera_1")
has_depth1 = dev1.create_depth_stream()
has_color1 = dev1.create_rgb_stream()

dev2 = OpenniDevice(uri=b'2bc5/0403@1/9', window_name="camera_2")
has_depth2 = dev2.create_depth_stream()
has_color2 = dev2.create_rgb_stream()

# dev3 = OpenniDevice(uri=b'2bc5/0501@1/22', window_name="camera_3")
# has_depth3 = dev3.create_depth_stream()
# has_color3 = dev3.create_rgb_stream()

# (device, has_depth, has_color) for every open camera; append dev3 here if enabled.
cameras = [(dev1, has_depth1, has_color1), (dev2, has_depth2, has_color2)]

QUIT_KEYS = (ord("q"), 27)  # 'q' or Esc ends the capture loop

# Restart the FPS window here so device start-up time is not counted as frame time.
start_time = time.time()
try:
    while True:
        for cam, has_depth, has_color in cameras:
            # Only poll streams that create_*_stream() reported as started.
            if has_depth:
                cam.depth_stream_image()
            if has_color:
                cam.rgb_stream_image()
        # TODO: devices without an OpenNI colour sensor show no RGB window; a
        # cv2.VideoCapture fallback could be added here if needed.

        counter += 1
        elapsed = time.time() - start_time
        if elapsed > x:
            print("FPS: ", counter / elapsed)
            counter = 0
            start_time = time.time()

        # waitKey also services the HighGUI event loop that repaints the windows;
        # it returns -1 when no key was pressed.
        if (cv2.waitKey(34) & 0xFF) in QUIT_KEYS:
            break
finally:
    # Runs on 'q'/Esc as well as on Ctrl+C or an exception inside the loop.
    for cam, has_depth, has_color in cameras:
        if has_depth:
            cam.depth_streamer.stop()
        if has_color:
            cam.rgb_streamer.stop()
    cv2.destroyAllWindows()
    openni2.unload()
# device_1=openni2.Device("b'19052130216'")

# Open a device
# dev = openni2.Device.open_any()
# serial_number = str(dev.get_property(c_api.ONI_DEVICE_PROPERTY_SERIAL_NUMBER, (ctypes.c_char * 100)).value)

# dev=device_1.open()
# print(dev.get_device_info())
# print(serial_number)


# d
