Source code for mmhuman3d.core.conventions.keypoints_mapping
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from mmcv.utils import print_log
from mmhuman3d.core.conventions.keypoints_mapping import (
agora,
coco,
coco_wholebody,
crowdpose,
gta,
h36m,
human_data,
hybrik,
instavariety,
lsp,
mpi_inf_3dhp,
mpii,
openpose,
penn_action,
posetrack,
pw3d,
smpl,
smplx,
)
# Registry of supported keypoint conventions. Each key is a convention name
# (the ``src`` / ``dst`` / ``convention`` strings accepted by the functions in
# this module; convert_kps and get_mapping look them up lowercased) and each
# value is that convention's sequence of keypoint names. A keypoint's position
# in the sequence is its index in arrays that follow the convention.
KEYPOINTS_FACTORY = {
    'human_data': human_data.HUMAN_DATA,
    'agora': agora.AGORA_KEYPOINTS,
    'coco': coco.COCO_KEYPOINTS,
    'coco_wholebody': coco_wholebody.COCO_WHOLEBODY_KEYPOINTS,
    'crowdpose': crowdpose.CROWDPOSE_KEYPOINTS,
    'smplx': smplx.SMPLX_KEYPOINTS,
    'smpl': smpl.SMPL_KEYPOINTS,
    'smpl_45': smpl.SMPL_45_KEYPOINTS,
    'smpl_54': smpl.SMPL_54_KEYPOINTS,
    'smpl_49': smpl.SMPL_49_KEYPOINTS,
    'smpl_24': smpl.SMPL_24_KEYPOINTS,
    'mpi_inf_3dhp': mpi_inf_3dhp.MPI_INF_3DHP_KEYPOINTS,
    'mpi_inf_3dhp_test': mpi_inf_3dhp.MPI_INF_3DHP_TEST_KEYPOINTS,
    'penn_action': penn_action.PENN_ACTION_KEYPOINTS,
    'h36m': h36m.H36M_KEYPOINTS,
    'h36m_mmpose': h36m.H36M_KEYPOINTS_MMPOSE,
    'pw3d': pw3d.PW3D_KEYPOINTS,
    'mpii': mpii.MPII_KEYPOINTS,
    'lsp': lsp.LSP_KEYPOINTS,
    'posetrack': posetrack.POSETRACK_KEYPOINTS,
    'instavariety': instavariety.INSTAVARIETY_KEYPOINTS,
    'openpose_25': openpose.OPENPOSE_25_KEYPOINTS,
    'openpose_135': openpose.OPENPOSE_135_KEYPOINTS,
    'hybrik_29': hybrik.HYBRIK_29_KEYPOINTS,
    # HybrIK's MPI-INF-3DHP variant lives in the mpi_inf_3dhp module.
    'hybrik_hp3d': mpi_inf_3dhp.HYBRIK_MPI_INF_3DHP_KEYPOINTS,
    'gta': gta.GTA_KEYPOINTS
}
# Memo for get_mapping(): __KEYPOINTS_MAPPING_CACHE__[src][dst] stores
# [dst_idxs, src_idxs, intersection_names, approximate]; the trailing
# ``approximate`` flag records which mode the entry was computed in.
__KEYPOINTS_MAPPING_CACHE__ = defaultdict(dict)
def _new_mask(like: Union[np.ndarray, torch.Tensor],
              length: int,
              fill: int = 0) -> Union[np.ndarray, torch.Tensor]:
    """Create a 1-D uint8 mask of ``length`` elements set to ``fill``, in the
    same backend as ``like`` (numpy, or torch on ``like``'s device)."""
    if isinstance(like, torch.Tensor):
        return torch.full((length, ),
                          fill,
                          dtype=torch.uint8,
                          device=like.device)
    return np.full((length, ), fill, dtype=np.uint8)


def _flatten_mask_like(
    mask: Union[np.ndarray, torch.Tensor],
    like: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
    """Flatten ``mask`` to 1-D and convert it to the backend of ``like`` so it
    can be assigned into a mask produced by ``_new_mask``."""
    if isinstance(like, torch.Tensor):
        return torch.as_tensor(mask, device=like.device).reshape(-1)
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    return np.asarray(mask).reshape(-1)


def convert_kps(
    keypoints: Union[np.ndarray, torch.Tensor],
    src: str,
    dst: str,
    approximate: bool = False,
    mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
    keypoints_factory: dict = KEYPOINTS_FACTORY,
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
    """Convert keypoints from the ``src`` convention to the ``dst``
    convention.

    Keypoints are matched by name between ``keypoints_factory[src]`` and
    ``keypoints_factory[dst]`` (see ``get_mapping``). Destination keypoints
    without a source counterpart are zero-filled and marked 0 in the mask.

    Args:
        keypoints (Union[np.ndarray, torch.Tensor]): input keypoints of shape
            (f, n, J, C) or (f, J, C), where J is the number of keypoints in
            the ``src`` convention and C is the per-keypoint channel count
            (e.g. 2 or 3). Pass ``np.zeros((1, J, 2))`` if only the mask is
            needed.
        src (str): source convention, a key of ``keypoints_factory``
            (looked up lowercased).
        dst (str): destination convention, a key of ``keypoints_factory``
            (looked up lowercased).
        approximate (bool): if True, destination keypoints without an exact
            name match may be filled from ``human_data.APPROXIMATE_MAP``
            candidates. Defaults to False.
        mask (Optional[Union[np.ndarray, torch.Tensor]], optional): existence
            mask of the source keypoints; must have J elements once
            flattened. None means every source keypoint exists.
            Defaults to None.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
            ``(out_keypoints, mask)``. ``out_keypoints`` has the input's
            leading dims, the ``dst`` keypoint count and the input dtype.
            ``mask`` is a 1-D uint8 array/tensor in the same backend (and
            device) as ``keypoints``. When ``src == dst`` the input keypoints
            are returned unchanged together with the (flattened) given mask,
            or an all-ones mask if none was given.

    Raises:
        TypeError: if ``keypoints`` is neither np.ndarray nor torch.Tensor.
    """
    # Check the backend first: everything below uses ndarray/Tensor
    # attributes (ndim, reshape, device).
    if not isinstance(keypoints, (np.ndarray, torch.Tensor)):
        raise TypeError('keypoints should be torch.Tensor or np.ndarray')
    assert keypoints.ndim in {3, 4}

    if src == dst:
        # Identity conversion: keep the keypoints as-is, but honour the
        # caller's mask and return it in the keypoints' backend.
        num_kps = keypoints.shape[-2]
        if mask is None:
            return keypoints, _new_mask(keypoints, num_kps, fill=1)
        flat_mask = _flatten_mask_like(mask, keypoints)
        assert flat_mask.shape[0] == num_kps, \
            f'The length of mask should be {num_kps}'
        out_mask = _new_mask(keypoints, num_kps)
        out_mask[:] = flat_mask
        return keypoints, out_mask

    src_names = keypoints_factory[src.lower()]
    dst_names = keypoints_factory[dst.lower()]
    num_channels = keypoints.shape[-1]
    # Fail early with a clear message instead of an obscure reshape error.
    assert keypoints.shape[-2] == len(src_names), (
        f'keypoints have {keypoints.shape[-2]} joints, but convention '
        f'{src} defines {len(src_names)}')

    extra_dims = keypoints.shape[:-2]
    keypoints = keypoints.reshape(-1, len(src_names), num_channels)
    if isinstance(keypoints, np.ndarray):
        # Preserve the input dtype, as the torch branch does.
        out_keypoints = np.zeros(
            (keypoints.shape[0], len(dst_names), num_channels),
            dtype=keypoints.dtype)
    else:
        out_keypoints = torch.zeros(
            (keypoints.shape[0], len(dst_names), num_channels),
            device=keypoints.device,
            dtype=keypoints.dtype)

    flat_mask = None
    if mask is not None:
        flat_mask = _flatten_mask_like(mask, keypoints)
        assert flat_mask.shape[0] == len(
            src_names), f'The length of mask should be {len(src_names)}'

    dst_idxs, src_idxs, _ = \
        get_mapping(src, dst, approximate, keypoints_factory)
    out_keypoints[:, dst_idxs] = keypoints[:, src_idxs]
    out_shape = tuple(extra_dims) + (len(dst_names), num_channels)
    out_keypoints = out_keypoints.reshape(out_shape)

    out_mask = _new_mask(keypoints, len(dst_names))
    if flat_mask is None:
        out_mask[dst_idxs] = 1
    else:
        out_mask[dst_idxs] = flat_mask[src_idxs]
    return out_keypoints, out_mask
def compress_converted_kps(
    zero_pad_array: Union[np.ndarray, torch.Tensor],
    mask_array: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
    """Drop the zero-padded keypoints that ``convert_kps`` inserted.

    Keeps only the entries along axis 1 whose mask value equals 1.

    Args:
        zero_pad_array (Union[np.ndarray, torch.Tensor]): converted keypoints
            whose axis 1 is the keypoint axis, e.g. shape (f, J, C).
        mask_array (Union[np.ndarray, torch.Tensor]): mask returned by
            ``convert_kps``; its length must equal ``zero_pad_array.shape[1]``.

    Returns:
        Union[np.ndarray, torch.Tensor]: the compressed keypoints, in the
            same backend (and device) as ``zero_pad_array``.
    """
    assert mask_array.shape[0] == zero_pad_array.shape[1]
    if isinstance(zero_pad_array, torch.Tensor):
        # np.take would silently convert a tensor to ndarray (and fails for
        # CUDA tensors), so select natively in torch.
        mask_tensor = torch.as_tensor(mask_array, device=zero_pad_array.device)
        valid_mask_index = torch.nonzero(mask_tensor == 1, as_tuple=True)[0]
        return zero_pad_array.index_select(1, valid_mask_index)
    if isinstance(mask_array, torch.Tensor):
        mask_array = mask_array.detach().cpu().numpy()
    valid_mask_index = np.where(mask_array == 1)[0]
    return np.take(zero_pad_array, valid_mask_index, axis=1)
def get_mapping(src: str,
                dst: str,
                approximate: bool = False,
                keypoints_factory: dict = KEYPOINTS_FACTORY):
    """Get the index correspondence between the ``src`` and ``dst``
    conventions.

    Each ``dst`` keypoint is matched by name to the first ``src`` keypoint of
    the same name. If ``approximate`` is True, an unmatched ``dst`` keypoint
    falls back to the first name in ``human_data.APPROXIMATE_MAP[dst_name]``
    that exists in ``src``; such fallbacks are reported via ``print_log``.

    Results are memoized in ``__KEYPOINTS_MAPPING_CACHE__``, but only for the
    module-level ``KEYPOINTS_FACTORY``: the cache is keyed by convention names
    alone, so results computed from a custom factory are neither read from
    nor written to it.

    Args:
        src (str): source convention, a key of ``keypoints_factory``
            (looked up lowercased).
        dst (str): destination convention, a key of ``keypoints_factory``
            (looked up lowercased).
        approximate (bool): whether approximate mapping is allowed.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        list: ``[dst_idxs, src_idxs, intersection_names]`` where
            ``dst_idxs[k]`` and ``src_idxs[k]`` are the indices of the k-th
            matched keypoint in ``dst`` and ``src`` respectively, and
            ``intersection_names[k]`` is its ``dst`` name.
    """
    use_cache = keypoints_factory is KEYPOINTS_FACTORY
    if use_cache:
        cached = __KEYPOINTS_MAPPING_CACHE__.get(src, {}).get(dst)
        if cached is not None and cached[3] == approximate:
            return cached[:3]

    src_names = keypoints_factory[src.lower()]
    dst_names = keypoints_factory[dst.lower()]

    # name -> first index in src, matching list.index semantics in O(1).
    src_lookup = {}
    for idx, name in enumerate(src_names):
        src_lookup.setdefault(name, idx)

    dst_idxs, src_idxs, intersection = [], [], []
    unmapped_names, approximate_names = [], []
    for dst_idx, dst_name in enumerate(dst_names):
        src_idx = src_lookup.get(dst_name)
        if src_idx is None and approximate:
            try:
                part_list = human_data.APPROXIMATE_MAP[dst_name]
            except KeyError:
                part_list = ()  # no approximation registered for this name
            for approximate_name in part_list:
                src_idx = src_lookup.get(approximate_name)
                if src_idx is not None:
                    unmapped_names.append(approximate_name)
                    approximate_names.append(dst_name)
                    break
        if src_idx is None:
            continue
        dst_idxs.append(dst_idx)
        src_idxs.append(src_idx)
        intersection.append(dst_name)

    if unmapped_names:
        print_log(msg=f'Approximate mapping {unmapped_names}'
                  f' to {approximate_names}')

    mapping_list = [dst_idxs, src_idxs, intersection, approximate]
    if use_cache:
        __KEYPOINTS_MAPPING_CACHE__[src][dst] = mapping_list
    return mapping_list[:3]
def get_flip_pairs(
        convention: str = 'smplx',
        keypoints_factory: dict = KEYPOINTS_FACTORY) -> List[List[int]]:
    """Get index pairs of left/right keypoints in a convention.

    Every keypoint whose name contains ``'left_'`` is paired with the
    keypoint named by replacing ``'left_'`` with ``'right_'``.

    Args:
        convention (str): a key of ``keypoints_factory``.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        List[List[int]]: ``[left_idx, right_idx]`` pairs, in the order the
            left keypoints appear in the convention.

    Raises:
        ValueError: if a left keypoint has no right counterpart
            (raised by ``list.index``).
    """
    keypoints = keypoints_factory[convention]
    return [[
        keypoints.index(name),
        keypoints.index(name.replace('left_', 'right_'))
    ] for name in keypoints if 'left_' in name]
def get_keypoint_idxs_by_part(
        part: str,
        convention: str = 'smplx',
        keypoints_factory: dict = KEYPOINTS_FACTORY) -> List[int]:
    """Get the indices of a body part's keypoints in a convention.

    Args:
        part (str): part name; must be a key of
            ``human_data.HUMAN_DATA_PARTS``.
        convention (str): a key of ``keypoints_factory``.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        List[int]: the first index of each of the part's keypoints that
            exists in the convention, in ascending order. (The previous
            implementation iterated a set of strings, so its order varied
            between interpreter runs.)

    Raises:
        ValueError: if ``part`` is not an allowed part.
    """
    humandata_parts = human_data.HUMAN_DATA_PARTS
    keypoints = keypoints_factory[convention]
    if part not in humandata_parts:
        raise ValueError('part not in allowed parts')
    part_names = set(humandata_parts[part])
    # Scanning in convention order and keeping the first occurrence gives the
    # same index set as keypoints.index(...) but in a deterministic order.
    first_idx = {}
    for idx, name in enumerate(keypoints):
        if name in part_names:
            first_idx.setdefault(name, idx)
    return list(first_idx.values())
def get_keypoint_idx(name: str,
                     convention: str = 'smplx',
                     approximate: bool = False,
                     keypoints_factory: dict = KEYPOINTS_FACTORY) -> int:
    """Get the index of a named keypoint in a convention.

    Args:
        name (str): keypoint name.
        convention (str): a key of ``keypoints_factory``.
        approximate (bool): if True and ``name`` is absent, try the names in
            ``human_data.APPROXIMATE_MAP[name]`` in order.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        int: first index of ``name`` (or of its first available approximate
            substitute), or -1 if nothing matches.
    """
    keypoints = keypoints_factory[convention]
    candidates = [name]
    if approximate and name not in keypoints:
        try:
            candidates.extend(human_data.APPROXIMATE_MAP[name])
        except KeyError:
            pass  # no approximation registered; only the exact name is tried
    for candidate in candidates:
        if candidate in keypoints:
            return keypoints.index(candidate)
    return -1  # not matched
def get_keypoint_num(convention: str = 'smplx',
                     keypoints_factory: dict = KEYPOINTS_FACTORY) -> int:
    """Get the number of keypoints defined by a convention.

    Args:
        convention (str): a key of ``keypoints_factory``.
        keypoints_factory (dict, optional): mapping from convention name to
            its sequence of keypoint names. Defaults to KEYPOINTS_FACTORY.

    Returns:
        int: number of keypoints in the convention.
    """
    return len(keypoints_factory[convention])