Shortcuts

Source code for mmhuman3d.models.losses.prior_loss

import os
import pickle
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmhuman3d.core.conventions.joints_mapping.standard_joint_angles import (
    STANDARD_JOINT_ANGLE_LIMITS,
    TRANSFORMATION_AA_TO_SJA,
    TRANSFORMATION_SJA_TO_AA,
)
from mmhuman3d.utils.transforms import aa_to_rot6d, aa_to_sja
from ..builder import LOSSES


[docs]@LOSSES.register_module() class ShapePriorLoss(nn.Module): """Prior loss for body shape parameters. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". loss_weight (float, optional): The weight of the loss. Defaults to 1.0 """ def __init__(self, reduction='mean', loss_weight=1.0): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight
[docs] def forward(self, betas, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: betas (torch.Tensor): The body shape parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) shape_prior_loss = loss_weight * betas**2 if reduction == 'mean': shape_prior_loss = shape_prior_loss.mean() elif reduction == 'sum': shape_prior_loss = shape_prior_loss.sum() return shape_prior_loss
[docs]@LOSSES.register_module() class JointPriorLoss(nn.Module): """Prior loss for joint angles. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". loss_weight (float, optional): The weight of the loss. Defaults to 1.0 use_full_body (bool, optional): Use full set of joint constraints (in standard joint angles). smooth_spine (bool, optional): Ensuring smooth spine rotations smooth_spine_loss_weight (float, optional): An additional weight factor multiplied on smooth spine loss """ def __init__(self, reduction='mean', loss_weight=1.0, use_full_body=False, smooth_spine=False, smooth_spine_loss_weight=1.0): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight self.use_full_body = use_full_body self.smooth_spine = smooth_spine self.smooth_spine_loss_weight = smooth_spine_loss_weight if self.use_full_body: self.register_buffer('R_t', TRANSFORMATION_AA_TO_SJA) self.register_buffer('R_t_inv', TRANSFORMATION_SJA_TO_AA) self.register_buffer('sja_limits', STANDARD_JOINT_ANGLE_LIMITS)
[docs] def forward(self, body_pose, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: body_pose (torch.Tensor): The body pose parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) if self.use_full_body: batch_size = body_pose.shape[0] body_pose_reshape = body_pose.reshape(batch_size, -1, 3) assert body_pose_reshape.shape[1] in (21, 23) # smpl-x, smpl body_pose_reshape = body_pose_reshape[:, :21, :] body_pose_sja = aa_to_sja(body_pose_reshape, self.R_t, self.R_t_inv) lower_limits = self.sja_limits[:, :, 0] # shape: (21, 3) upper_limits = self.sja_limits[:, :, 1] # shape: (21, 3) lower_loss = (torch.exp(F.relu(lower_limits - body_pose_sja)) - 1).pow(2) upper_loss = (torch.exp(F.relu(body_pose_sja - upper_limits)) - 1).pow(2) standard_joint_angle_prior_loss = (lower_loss + upper_loss).view( body_pose.shape[0], -1) # shape: (n, 3) joint_prior_loss = standard_joint_angle_prior_loss else: # default joint prior loss applied on elbows and knees joint_prior_loss = (torch.exp( body_pose[:, [55, 58, 12, 15]] * torch.tensor([1., -1., -1, -1.], device=body_pose.device)) - 1)**2 if self.smooth_spine: spine1 = body_pose[:, [9, 10, 11]] spine2 = body_pose[:, [18, 19, 20]] spine3 = body_pose[:, [27, 28, 29]] smooth_spine_loss_12 = (torch.exp(F.relu(-spine1 * spine2)) - 1).pow(2) * self.smooth_spine_loss_weight smooth_spine_loss_23 = (torch.exp(F.relu(-spine2 * spine3)) - 1).pow(2) * self.smooth_spine_loss_weight joint_prior_loss = torch.cat( [joint_prior_loss, smooth_spine_loss_12, smooth_spine_loss_23], axis=1) joint_prior_loss = loss_weight * joint_prior_loss if reduction == 'mean': joint_prior_loss = joint_prior_loss.mean() elif reduction == 'sum': joint_prior_loss = joint_prior_loss.sum() return joint_prior_loss
[docs]@LOSSES.register_module() class SmoothJointLoss(nn.Module): """Smooth loss for joint angles. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". loss_weight (float, optional): The weight of the loss. Defaults to 1.0 degree (bool, optional): The flag which represents whether the input tensor is in degree or radian. """ def __init__(self, reduction='mean', loss_weight=1.0, degree=False): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight self.degree = degree
[docs] def forward(self, body_pose, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: body_pose (torch.Tensor): The body pose parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) theta = body_pose.reshape(body_pose.shape[0], -1, 3) if self.degree: theta = torch.deg2rad(theta) rot_6d = aa_to_rot6d(theta) rot_6d_diff = rot_6d[1:] - rot_6d[:-1] smooth_joint_loss = rot_6d_diff.abs().sum(dim=-1) smooth_joint_loss = torch.cat( [torch.zeros_like(smooth_joint_loss)[:1], smooth_joint_loss]).sum(dim=-1) smooth_joint_loss = loss_weight * smooth_joint_loss if reduction == 'mean': smooth_joint_loss = smooth_joint_loss.mean() elif reduction == 'sum': smooth_joint_loss = smooth_joint_loss.sum() return smooth_joint_loss
[docs]@LOSSES.register_module() class SmoothPelvisLoss(nn.Module): """Smooth loss for pelvis angles. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". loss_weight (float, optional): The weight of the loss. Defaults to 1.0 degree (bool, optional): The flag which represents whether the input tensor is in degree or radian. """ def __init__(self, reduction='mean', loss_weight=1.0, degree=False): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight self.degree = degree
[docs] def forward(self, global_orient, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: global_orient (torch.Tensor): The global orientation parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) if self.degree: global_orient = torch.deg2rad(global_orient) pelvis = global_orient.unsqueeze(1) rot_6d = aa_to_rot6d(pelvis) rot_6d_diff = rot_6d[1:] - rot_6d[:-1] smooth_pelvis_loss = rot_6d_diff.abs().sum(dim=-1) smooth_pelvis_loss = torch.cat( [torch.zeros_like(smooth_pelvis_loss)[:1], smooth_pelvis_loss]).sum(dim=-1) smooth_pelvis_loss = loss_weight * smooth_pelvis_loss if reduction == 'mean': smooth_pelvis_loss = smooth_pelvis_loss.mean() elif reduction == 'sum': smooth_pelvis_loss = smooth_pelvis_loss.sum() return smooth_pelvis_loss
[docs]@LOSSES.register_module() class SmoothTranslationLoss(nn.Module): """Smooth loss for translations. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". loss_weight (float, optional): The weight of the loss. Defaults to 1.0 """ def __init__(self, reduction='mean', loss_weight=1.0): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight
[docs] def forward(self, translation, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: translation (torch.Tensor): The body translation parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) translation_diff = translation[1:] - translation[:-1] smooth_translation_loss = translation_diff.abs().sum( dim=-1, keepdim=True) smooth_translation_loss = torch.cat([ torch.zeros_like(smooth_translation_loss)[:1], smooth_translation_loss ]).sum(dim=-1) smooth_translation_loss *= 1e3 smooth_translation_loss = loss_weight * \ smooth_translation_loss if reduction == 'mean': smooth_translation_loss = smooth_translation_loss.mean() elif reduction == 'sum': smooth_translation_loss = smooth_translation_loss.sum() return smooth_translation_loss
[docs]@LOSSES.register_module() class CameraPriorLoss(nn.Module): """Prior loss for predicted camera. Args: reduction (str, optional): The method that reduces the loss to a scalar. Options are "none", "mean" and "sum". scale (float, optional): The scale coefficient for regularizing camera parameters. Defaults to 10 loss_weight (float, optional): The weight of the loss. Defaults to 1.0 """ def __init__(self, scale=10, reduction='mean', loss_weight=1.0): super().__init__() self.scale = scale self.reduction = reduction self.loss_weight = loss_weight
[docs] def forward(self, cameras, loss_weight_override=None, reduction_override=None): """Forward function of loss. Args: cameras (torch.Tensor): The predicted camera parameters loss_weight_override (float, optional): The weight of loss used to override the original weight of loss reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) camera_prior_loss = torch.exp(-cameras[:, 0] * self.scale) camera_prior_loss = torch.pow(camera_prior_loss, 2) * loss_weight if reduction == 'mean': camera_prior_loss = camera_prior_loss.mean() elif reduction == 'sum': camera_prior_loss = camera_prior_loss.sum() return camera_prior_loss
[docs]@LOSSES.register_module() class MaxMixturePrior(nn.Module): """Ref: SMPLify-X https://github.com/vchoutas/smplify-x/blob/master/smplifyx/prior.py """ def __init__(self, prior_folder='data', num_gaussians=8, dtype=torch.float32, epsilon=1e-16, use_merged=True, reduction=None, loss_weight=1.0): super(MaxMixturePrior, self).__init__() assert reduction in (None, 'none', 'mean', 'sum') self.reduction = reduction self.loss_weight = loss_weight if dtype == torch.float32: np_dtype = np.float32 elif dtype == torch.float64: np_dtype = np.float64 else: print('Unknown float type {}, exiting!'.format(dtype)) sys.exit(-1) self.num_gaussians = num_gaussians self.epsilon = epsilon self.use_merged = use_merged gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) full_gmm_fn = os.path.join(prior_folder, gmm_fn) if not os.path.exists(full_gmm_fn): print('The path to the mixture prior "{}"'.format(full_gmm_fn) + ' does not exist, exiting!') sys.exit(-1) with open(full_gmm_fn, 'rb') as f: gmm = pickle.load(f, encoding='latin1') if type(gmm) == dict: means = gmm['means'].astype(np_dtype) covs = gmm['covars'].astype(np_dtype) weights = gmm['weights'].astype(np_dtype) elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): means = gmm.means_.astype(np_dtype) covs = gmm.covars_.astype(np_dtype) weights = gmm.weights_.astype(np_dtype) else: print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) sys.exit(-1) self.register_buffer('means', torch.tensor(means, dtype=dtype)) self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) precisions = [np.linalg.inv(cov) for cov in covs] precisions = np.stack(precisions).astype(np_dtype) self.register_buffer('precisions', torch.tensor(precisions, dtype=dtype)) # The constant term: sqrdets = np.array([(np.sqrt(np.linalg.det(c))) for c in gmm['covars']]) const = (2 * np.pi)**(69 / 2.) nll_weights = np.asarray(gmm['weights'] / (const * (sqrdets / sqrdets.min()))) nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) self.register_buffer('nll_weights', nll_weights) weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) self.register_buffer('weights', weights) self.register_buffer('pi_term', torch.log(torch.tensor(2 * np.pi, dtype=dtype))) cov_dets = [ np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) for cov in covs ] self.register_buffer('cov_dets', torch.tensor(cov_dets, dtype=dtype)) # The dimensionality of the random variable self.random_var_dim = self.means.shape[1]
[docs] def get_mean(self): """Returns the mean of the mixture.""" mean_pose = torch.matmul(self.weights, self.means) return mean_pose
def merged_log_likelihood(self, pose): diff_from_mean = pose.unsqueeze(dim=1) - self.means prec_diff_prod = torch.einsum('mij,bmj->bmi', [self.precisions, diff_from_mean]) diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) curr_loglikelihood = 0.5 * diff_prec_quadratic - \ torch.log(self.nll_weights) # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + # self.random_var_dim * self.pi_term + # diff_prec_quadratic # ) - torch.log(self.weights) min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) return min_likelihood
[docs] def log_likelihood(self, pose): """Create graph operation for negative log-likelihood calculation.""" likelihoods = [] for idx in range(self.num_gaussians): mean = self.means[idx] prec = self.precisions[idx] cov = self.covs[idx] diff_from_mean = pose - mean curr_loglikelihood = torch.einsum('bj,ji->bi', [diff_from_mean, prec]) curr_loglikelihood = torch.einsum( 'bi,bi->b', [curr_loglikelihood, diff_from_mean]) cov_term = torch.log(torch.det(cov) + self.epsilon) curr_loglikelihood += 0.5 * ( cov_term + self.random_var_dim * self.pi_term) likelihoods.append(curr_loglikelihood) log_likelihoods = torch.stack(likelihoods, dim=1) min_idx = torch.argmin(log_likelihoods, dim=1) weight_component = self.nll_weights[:, min_idx] weight_component = -torch.log(weight_component) return weight_component + log_likelihoods[:, min_idx]
[docs] def forward(self, body_pose, loss_weight_override=None, reduction_override=None): assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss_weight = ( loss_weight_override if loss_weight_override is not None else self.loss_weight) if self.use_merged: pose_prior_loss = self.merged_log_likelihood(body_pose) else: pose_prior_loss = self.log_likelihood(body_pose) pose_prior_loss = loss_weight * pose_prior_loss if reduction == 'mean': pose_prior_loss = pose_prior_loss.mean() elif reduction == 'sum': pose_prior_loss = pose_prior_loss.sum() return pose_prior_loss