import h5py
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (needed on older matplotlib versions)

# filename = "/home/daniela/catkin_ws/src/hpe/images/human36m/processed/S1/Directions-1/annot.h5"
# filename = "/home/daniela/catkin_ws/src/hpe/images/human36m/extra/una_dinosauria-data/h36m/cameras.h5"
filename = "/home/daniela/catkin_ws/src/hpe/images/human36m/extra/una-dinosauria-data/h36m/S1/MyPoses/3D_positions/Directions 1.h5"

with h5py.File(filename, "r") as f:
    poses3d = f["3D_positions"][:]
    pose3d = poses3d[:,0]
    print(pose3d)
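
    # Sketch: split the flattened frame into per-joint (x, y, z) rows.
    # Assumption (not verified against this file): each column of
    # '3D_positions' is one frame of 32 joints stored as x0, y0, z0, x1, ...,
    # in millimetres, as in the una-dinosauria release.
    n_joints = pose3d.shape[0] // 3
    joints = pose3d.reshape(n_joints, 3) / 1000.0  # millimetres -> metres
    print("frame 0, %d joints (in metres):" % n_joints)
    print(joints)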







# filename="/home/daniela/catkin_ws/src/hpe/images/human36m/extra/una-dinosauria-data/h36m/cameras.h5"
filename = "/home/daniela/catkin_ws/src/hpe/images/human36m/processed/S1/Directions-1/annot.h5"

with h5py.File(filename, "r") as f:
    # Print all root-level object names (aka keys);
    # these can be group or dataset names.
    print("Keys: %s" % f.keys())
    # pick one of the root-level keys (here the sixth); it may or may not be a group
    a_group_key = list(f.keys())[5]

    # a_group_key is expected to name a group holding the 3D annotations
    pose3d = f[a_group_key]['3d']
    pose3d_univ = f[a_group_key]['3d-univ']
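
    # Sketch: instead of guessing keys by index, walk the whole file once.
    # h5py's visititems() calls the given function for every group/dataset
    # with its full path, which makes the layout of annot.h5 explicit.
    def print_node(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)
        else:
            print(name, "(group)")
    f.visititems(print_node)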

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
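    # Sketch: label the axes (values below are divided by 1000, i.e. Human3.6M
    # millimetres converted to metres) and keep a cubic box aspect so the
    # skeleton is not stretched. set_box_aspect needs matplotlib >= 3.3.
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_zlabel("z [m]")
    ax.set_box_aspect((1, 1, 1))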

    for i in range(0, 500):
        # sample a few frames (with i == 0 all four indices point to frame 0)
        sklt_1 = pose3d[i] / 1000       # millimetres -> metres
        sklt_2 = pose3d[2 * i] / 1000
        sklt_3 = pose3d[3 * i] / 1000
        sklt_4 = pose3d[4 * i] / 1000

        print("camera_1")
        print(sklt_1)

        # print("camera_2")
        # print(sklt_2)
        #
        # print("camera_3")
        # print(sklt_3)
        #
        # print("camera_4")
        # print(sklt_4)

        joint_idx = 0  # separate counter so the outer loop variable i is not clobbered

        for point in sklt_1:
            x = point[0]
            y = point[1]
            z = point[2]

            ax.scatter(x, y, z, c='r')
            ax.text(x, y, z, str(joint_idx))
            joint_idx += 1

        # for point in sklt_2:
        #     x = point[0]
        #     y = point[1]
        #     z = point[2]
        #
        #     ax.scatter(x, y, z, c='g')
        #
        # for point in sklt_3:
        #     x = point[0]
        #     y = point[1]
        #     z = point[2]
        #
        #     ax.scatter(x, y, z, c='b')
        #
        # for point in sklt_4:
        #     x = point[0]
        #     y = point[1]
        #     z = point[2]
        #
        #     ax.scatter(x, y, z, c='k')
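
        # Sketch: connect a few joints with line segments so the scatter reads
        # as a skeleton. The edge list is purely illustrative (NOT the official
        # Human3.6M connectivity); out-of-range indices are skipped.
        example_bones = [(0, 1), (1, 2), (2, 3)]  # hypothetical parent-child pairs
        for a, b in example_bones:
            if a < len(sklt_1) and b < len(sklt_1):
                xs = [sklt_1[a][0], sklt_1[b][0]]
                ys = [sklt_1[a][1], sklt_1[b][1]]
                zs = [sklt_1[a][2], sklt_1[b][2]]
                ax.plot(xs, ys, zs, c='r')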

        plt.show()
        exit(0)  # show only the first sampled frame, then stop

    # print(pose3d[0]/1000)
    # print(pose3d[500]/1000)
    # print(list(f[a_group_key]['camera1']))
    # print(list(f[a_group_key]['camera1']['k']))
    # print(f[a_group_key]['54138969'][()])
    # print(list(f[a_group_key]['3d-univ']))
    # If a_group_key is a group name,
    # this gets the object names in the group and returns as a list
    # data = list(f[a_group_key])

    # If a_group_key is a dataset name,
    # this gets the dataset values and returns as a list
    # data = list(f[a_group_key])
    # preferred methods to get dataset values:
    # ds_obj = f[a_group_key]      # returns as a h5py dataset object
    # ds_arr = f[a_group_key][()]  # returns as a numpy array
    # print(ds_obj['camera1'].keys())

# import torch
# import numpy as np
# import cv2
# from tqdm import tqdm
#
# import os, sys
#
# from mvn.datasets.human36m import Human36MMultiViewDataset  # with the learnable-triangulation-pytorch repo root on sys.path
#
# h36m_root = "/home/daniela/catkin_ws/src/hpe/images/human36m/processed"
# # labels_multiview_npy_path = sys.argv[2]
# # number_of_processes = int(sys.argv[3])
#
#
# dataset = Human36MMultiViewDataset(
#     h36m_root,
#     # labels_multiview_npy_path,
#     train=True,                       # include all possible data
#     test=True,
#     image_shape=None,                 # don't resize
#     retain_every_n_frames_in_test=1,  # yes actually ALL possible data
#     with_damaged_actions=True,        # I said ALL DATA
#     kind="mpii",
#     norm_image=False,                 # don't do unnecessary image processing
#     crop=False)                       # don't crop
# print("Dataset length:", len(dataset))
