Source code for kimodo.constraints

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constraint sets for conditioning motion generation (root 2D, full body, end-effectors)."""

from typing import Optional, Union

import torch
from torch import Tensor

from kimodo.motion_rep.feature_utils import compute_heading_angle
from kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77
from kimodo.tools import ensure_batched, load_json, save_json

from .geometry import axis_angle_to_matrix, matrix_to_axis_angle


def _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor:
    """Convert loaded local rotation matrices to match the skeleton's joint count.

    Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session
    skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77.
    """
    n_joints = local_rot_mats.shape[-3]
    skeleton_joints = skeleton.nbjoints
    if n_joints == skeleton_joints:
        return local_rot_mats
    if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30):
        return skeleton.from_SOMASkeleton77(local_rot_mats)
    if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77):
        skel30 = SOMASkeleton30()
        return skel30.to_SOMASkeleton77(local_rot_mats)
    raise ValueError(
        f"Constraint joint count ({n_joints}) does not match skeleton joint count "
        f"({skeleton_joints}). Only SOMA 30↔77 conversion is supported."
    )


def create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor:
    """Build every (a, b) combination of two 1D tensors.

    Pairs are ordered with ``tensor_A`` varying slowest: all pairs for ``tensor_A[0]`` come
    first, then all pairs for ``tensor_A[1]``, and so on.

    Args:
        tensor_A: 1D tensor providing the first element of each pair.
        tensor_B: 1D tensor providing the second element of each pair.

    Returns:
        Tensor of shape ``(len(tensor_A) * len(tensor_B), 2)``.
    """
    n_a = len(tensor_A)
    n_b = len(tensor_B)

    # Broadcast both inputs onto a common (n_a, n_b) grid without copying.
    grid_a = tensor_A.unsqueeze(1).expand(n_a, n_b)
    grid_b = tensor_B.unsqueeze(0).expand(n_a, n_b)

    # Stack along a new trailing axis, then flatten the grid into a list of pairs.
    stacked = torch.stack((grid_a, grid_b), dim=-1)
    return stacked.reshape(-1, 2)
def compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor:
    """Return the global root heading as a (cos, sin) pair.

    The heading angle is obtained from ``compute_heading_angle`` and then encoded by stacking
    its cosine and sine along a new last axis.

    Args:
        global_joints_positions: Global joint positions passed through to ``compute_heading_angle``.
        skeleton: Skeleton passed through to ``compute_heading_angle``.

    Returns:
        Tensor whose last axis has size 2 and holds ``(cos(angle), sin(angle))``.
    """
    heading = compute_heading_angle(global_joints_positions, skeleton)
    cos_part = torch.cos(heading)
    sin_part = torch.sin(heading)
    return torch.stack([cos_part, sin_part], dim=-1)
def _tensor_to( t: Tensor, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, ) -> Tensor: """Move tensor to device and/or dtype. Returns same tensor if no args. """ if device is not None and dtype is not None: return t.to(device=device, dtype=dtype) if device is not None: return t.to(device=device) if dtype is not None: return t.to(dtype=dtype) return t
class Root2DConstraintSet:
    """Constraint set fixing root (x, z) trajectory and optionally global heading on given frames."""

    name = "root2d"

    def __init__(
        self,
        skeleton: SkeletonBase,
        frame_indices: Tensor,
        smooth_root_2d: Tensor,
        to_crop: bool = False,
        global_root_heading: Optional[Tensor] = None,
    ) -> None:
        """Store the constrained root trajectory.

        Args:
            skeleton: Skeleton associated with the constraint (stored as-is).
            frame_indices: Indices of the constrained frames.
            smooth_root_2d: Root trajectory. A last axis of size 3 is treated as a full 3D
                root and reduced to its (x, z) components.
            to_crop: If True, ``smooth_root_2d`` (and ``global_root_heading``) cover the whole
                motion and are indexed with ``frame_indices``. If False they must already have
                one entry per constrained frame.
            global_root_heading: Optional heading values for the constrained frames.
        """
        self.skeleton = skeleton

        # If the full smooth root 3D is passed, keep only the ground-plane (x, z) components.
        # Index 1 is the vertical axis (the rest of this module stores it as root_y_pos).
        if smooth_root_2d.shape[-1] == 3:
            smooth_root_2d = smooth_root_2d[..., [0, 2]]

        if to_crop:
            smooth_root_2d = smooth_root_2d[frame_indices]
            if global_root_heading is not None:
                global_root_heading = global_root_heading[frame_indices]
        else:
            assert len(smooth_root_2d) == len(
                frame_indices
            ), "The number of smooth root 2d should be match the number of frames"
            if global_root_heading is not None:
                assert len(global_root_heading) == len(
                    frame_indices
                ), "The number of global root heading should be match the number of frames"

        self.smooth_root_2d = smooth_root_2d
        self.global_root_heading = global_root_heading
        self.frame_indices = frame_indices

    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
        """Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index dicts."""
        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
        index_dict["smooth_root_2d"].append(self.frame_indices)

        if self.global_root_heading is not None:
            # Also constrain the global heading on the same frames.
            data_dict["global_root_heading"].append(self.global_root_heading)
            index_dict["global_root_heading"].append(self.frame_indices)

    def crop_move(self, start: int, end: int) -> "Root2DConstraintSet":
        """Return a new constraint set for the cropped frame range [start, end).

        Frame indices of the result are shifted so that ``start`` becomes frame 0.
        """
        mask = (self.frame_indices >= start) & (self.frame_indices < end)

        masked_global_root_heading = None
        if self.global_root_heading is not None:
            masked_global_root_heading = self.global_root_heading[mask]

        return Root2DConstraintSet(
            self.skeleton,
            self.frame_indices[mask] - start,
            self.smooth_root_2d[mask],
            global_root_heading=masked_global_root_heading,
        )

    def get_save_info(self) -> dict:
        """Return a dict describing this constraint.

        The dict holds ``type``, ``frame_indices``, ``smooth_root_2d`` and, when set,
        ``global_root_heading``. Values are still tensors at this point.
        """
        out = {
            "type": self.name,
            "frame_indices": self.frame_indices,
            "smooth_root_2d": self.smooth_root_2d,
        }
        if self.global_root_heading is not None:
            out["global_root_heading"] = self.global_root_heading
        return out

    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Root2DConstraintSet":
        """Move the stored tensors to ``device`` and/or cast the data tensors to ``dtype``, in place.

        ``frame_indices`` is only moved across devices: it is used to index other tensors, so
        casting it to a floating dtype would make it unusable as an index.

        Returns:
            ``self``, to allow chaining.
        """
        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
        # Index tensor: change device only, never the dtype.
        self.frame_indices = _tensor_to(self.frame_indices, device)
        if self.global_root_heading is not None:
            self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
        if device is not None and hasattr(self.skeleton, "to"):
            self.skeleton = self.skeleton.to(device)
        return self

    @classmethod
    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "Root2DConstraintSet":
        """Build a Root2DConstraintSet from a dict (e.g. loaded from JSON).

        Data tensors are created on ``skeleton.device`` when the skeleton exposes one, else on CPU.
        """
        device = skeleton.device if hasattr(skeleton, "device") else "cpu"

        global_root_heading = None
        if "global_root_heading" in dico:
            global_root_heading = torch.tensor(dico["global_root_heading"], device=device)

        return cls(
            skeleton,
            frame_indices=torch.tensor(dico["frame_indices"]),
            smooth_root_2d=torch.tensor(dico["smooth_root_2d"], device=device),
            global_root_heading=global_root_heading,
        )
class FullBodyConstraintSet:
    """Constraint set fixing full-body global positions and rotations on given keyframes."""

    name = "fullbody"

    def __init__(
        self,
        skeleton: SkeletonBase,
        frame_indices: Tensor,
        global_joints_positions: Tensor,
        global_joints_rots: Tensor,
        smooth_root_2d: Optional[Tensor] = None,
        to_crop: bool = False,
    ) -> None:
        """Store the constrained full-body poses and derive root height and heading.

        Args:
            skeleton: Skeleton used for the root index and heading computation.
            frame_indices: Indices of the constrained frames.
            global_joints_positions: Global joint positions, indexed as [frame, joint, xyz].
            global_joints_rots: Global joint rotations, one entry per frame.
            smooth_root_2d: Optional smooth root trajectory. A last axis of size 3 is reduced
                to (x, z). When omitted, the root joint's (x, z) position is used.
            to_crop: If True, the inputs cover the whole motion and are indexed with
                ``frame_indices``. If False they must already have one entry per constrained frame.
        """
        self.skeleton = skeleton
        self.frame_indices = frame_indices

        # If the full smooth root 3D is passed, keep only the ground-plane (x, z) components,
        # matching the [0, 2] fallback below (index 1 is the height, stored as root_y_pos).
        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
            smooth_root_2d = smooth_root_2d[..., [0, 2]]

        if to_crop:
            global_joints_positions = global_joints_positions[frame_indices]
            global_joints_rots = global_joints_rots[frame_indices]
            if smooth_root_2d is not None:
                smooth_root_2d = smooth_root_2d[frame_indices]
        else:
            assert len(global_joints_positions) == len(
                frame_indices
            ), "The number of global positions should be match the number of frames"
            assert len(global_joints_rots) == len(
                frame_indices
            ), "The number of global joint rotations should be match the number of frames"
            if smooth_root_2d is not None:
                assert len(smooth_root_2d) == len(
                    frame_indices
                ), "The number of smooth root 2d (if specified) should be match the number of frames"

        if smooth_root_2d is None:
            # Substitute the smooth root 2d with the real root.
            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]

        # Root y: from smooth or pelvis is the same.
        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
        self.global_joints_positions = global_joints_positions
        self.global_joints_rots = global_joints_rots
        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
        self.smooth_root_2d = smooth_root_2d

    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
        """Append global positions, smooth root 2D, root y, and global heading to data/index dicts."""
        nbjoints = self.skeleton.nbjoints
        # One (frame, joint) index pair per constrained joint position.
        indices_lst = create_pairs(
            self.frame_indices,
            torch.arange(nbjoints, device=self.frame_indices.device),
        )
        # Flatten the global positions so they line up with the (frame, joint) pairs.
        data_dict["global_joints_positions"].append(self.global_joints_positions.reshape(-1, 3))
        index_dict["global_joints_positions"].append(indices_lst)

        # Global rotations are not used here.

        # As we use the smooth root, also constrain the smooth root to get the same full body.
        # Maybe keep storing the hips offset, if we smooth it ourselves.
        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
        index_dict["smooth_root_2d"].append(self.frame_indices)

        # Constrain the y position of the root.
        data_dict["root_y_pos"].append(self.root_y_pos)
        index_dict["root_y_pos"].append(self.frame_indices)

        # Constrain the global heading.
        data_dict["global_root_heading"].append(self.global_root_heading)
        index_dict["global_root_heading"].append(self.frame_indices)

    def crop_move(self, start: int, end: int) -> "FullBodyConstraintSet":
        """Return a new FullBodyConstraintSet for the cropped frame range [start, end).

        Frame indices of the result are shifted so that ``start`` becomes frame 0.
        """
        mask = (self.frame_indices >= start) & (self.frame_indices < end)
        return FullBodyConstraintSet(
            self.skeleton,
            self.frame_indices[mask] - start,
            self.global_joints_positions[mask],
            self.global_joints_rots[mask],
            self.smooth_root_2d[mask],
        )

    def get_save_info(self) -> dict:
        """Return a dict for saving: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d.

        Rotations are converted to local rotations and stored as axis-angle. For a
        ``SOMASkeleton30`` they are first expanded with ``to_SOMASkeleton77``.
        """
        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
        if isinstance(self.skeleton, SOMASkeleton30):
            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
        local_joints_rot = matrix_to_axis_angle(local_joints_rot)

        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
        return {
            "type": self.name,
            "frame_indices": self.frame_indices,
            "local_joints_rot": local_joints_rot,
            "root_positions": root_positions,
            "smooth_root_2d": self.smooth_root_2d,
        }

    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "FullBodyConstraintSet":
        """Move the stored tensors to ``device`` and/or cast the data tensors to ``dtype``, in place.

        ``frame_indices`` is only moved across devices: it is used to index and pair with other
        tensors, so casting it to a floating dtype would make it unusable as an index.

        Returns:
            ``self``, to allow chaining.
        """
        # Index tensor: change device only, never the dtype.
        self.frame_indices = _tensor_to(self.frame_indices, device)
        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
        if device is not None and hasattr(self.skeleton, "to"):
            self.skeleton = self.skeleton.to(device)
        return self

    @classmethod
    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "FullBodyConstraintSet":
        """Build a FullBodyConstraintSet from a dict (e.g. loaded from JSON).

        The stored axis-angle local rotations are converted to matrices, adapted to the
        skeleton's joint count, and run through ``skeleton.fk`` together with the stored root
        positions to recover global rotations and positions.
        """
        frame_indices = torch.tensor(dico["frame_indices"])
        device = skeleton.device if hasattr(skeleton, "device") else "cpu"

        local_rot = torch.tensor(dico["local_joints_rot"], device=device)
        local_rot_mats = axis_angle_to_matrix(local_rot)
        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)

        global_joints_rots, global_joints_positions, _ = skeleton.fk(
            local_rot_mats,
            torch.tensor(dico["root_positions"], device=device),
        )

        smooth_root_2d = None
        if "smooth_root_2d" in dico:
            smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)

        return cls(
            skeleton,
            frame_indices=frame_indices,
            global_joints_positions=global_joints_positions,
            global_joints_rots=global_joints_rots,
            smooth_root_2d=smooth_root_2d,
        )
class EndEffectorConstraintSet:
    """Constraint set fixing selected end-effector positions and rotations on given frames."""

    name = "end-effector"

    def __init__(
        self,
        skeleton: SkeletonBase,
        frame_indices: Tensor,
        global_joints_positions: Tensor,
        global_joints_rots: Tensor,
        smooth_root_2d: Optional[Tensor],
        *,
        joint_names: list[str],
        to_crop: bool = False,
    ) -> None:
        """Store the constrained poses and resolve which joints are constrained.

        Args:
            skeleton: Skeleton used for joint lookup, root index and heading computation.
            frame_indices: Indices of the constrained frames.
            global_joints_positions: Global joint positions, indexed as [frame, joint, xyz].
            global_joints_rots: Global joint rotations, one entry per frame.
            smooth_root_2d: Optional smooth root trajectory. A last axis of size 3 is reduced
                to (x, z). When None, the root joint's (x, z) position is used.
            joint_names: Names of the constrained joints, constant over all frames. They are
                expanded through ``skeleton.expand_joint_names`` into position and rotation joints.
            to_crop: If True, the inputs cover the whole motion and are indexed with
                ``frame_indices``. If False they must already have one entry per constrained frame.
        """
        self.skeleton = skeleton
        self.frame_indices = frame_indices
        self.joint_names = joint_names

        # joint_names are constant for all the frames
        rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names)
        # indexing works for motion_rep with smooth root only (contains pelvis index)
        self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names])
        self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names])

        # If the full smooth root 3D is passed, keep only the ground-plane (x, z) components,
        # matching the [0, 2] fallback below (index 1 is the height, stored as root_y_pos).
        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
            smooth_root_2d = smooth_root_2d[..., [0, 2]]

        if to_crop:
            global_joints_positions = global_joints_positions[frame_indices]
            global_joints_rots = global_joints_rots[frame_indices]
            if smooth_root_2d is not None:
                smooth_root_2d = smooth_root_2d[frame_indices]
        else:
            assert len(global_joints_positions) == len(
                frame_indices
            ), "The number of global positions should be match the number of frames"
            assert len(global_joints_rots) == len(
                frame_indices
            ), "The number of global joint rotations should be match the number of frames"
            if smooth_root_2d is not None:
                assert len(smooth_root_2d) == len(
                    frame_indices
                ), "The number of smooth root 2d (if specified) should be match the number of frames"

        if smooth_root_2d is None:
            # Substitute the smooth root 2d with the real root.
            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]

        # Root y: from smooth or pelvis is the same.
        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
        self.global_joints_positions = global_joints_positions
        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
        self.global_joints_rots = global_joints_rots
        self.smooth_root_2d = smooth_root_2d

    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
        """Append constrained joint positions/rots, smooth root 2D, root y, and heading to data/index dicts.

        Two index sets are built per quantity: "real" pairs use the actual frame indices and are
        written to ``index_dict``, while "crop" pairs use positions 0..len(frame_indices)-1 and
        are used to gather values from the stored per-keyframe tensors.
        """
        crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device)

        # Constrain positions.
        pos_indices_real = create_pairs(self.frame_indices, self.pos_indices)
        pos_indices_crop = create_pairs(crop_frames_indexing, self.pos_indices)
        data_dict["global_joints_positions"].append(self.global_joints_positions[tuple(pos_indices_crop.T)])
        index_dict["global_joints_positions"].append(pos_indices_real)

        # Constrain rotations.
        rot_indices_real = create_pairs(self.frame_indices, self.rot_indices)
        rot_indices_crop = create_pairs(crop_frames_indexing, self.rot_indices)
        data_dict["global_joints_rots"].append(self.global_joints_rots[tuple(rot_indices_crop.T)])
        index_dict["global_joints_rots"].append(rot_indices_real)

        # As we use the smooth root, also constrain the smooth root to get the same full body.
        # Maybe keep storing the hips offset, if we smooth it ourselves.
        data_dict["smooth_root_2d"].append(self.smooth_root_2d)
        index_dict["smooth_root_2d"].append(self.frame_indices)

        # Constrain the y position of the root.
        data_dict["root_y_pos"].append(self.root_y_pos)
        index_dict["root_y_pos"].append(self.frame_indices)

        # Constrain the global heading.
        data_dict["global_root_heading"].append(self.global_root_heading)
        index_dict["global_root_heading"].append(self.frame_indices)

    def crop_move(self, start: int, end: int) -> "EndEffectorConstraintSet":
        """Return a new EndEffectorConstraintSet for the cropped frame range [start, end).

        The result has the same class as ``self``, and its frame indices are shifted so that
        ``start`` becomes frame 0.
        """
        mask = (self.frame_indices >= start) & (self.frame_indices < end)

        cls = type(self)
        kwargs = {}
        # Only classes without a class-level joint_names take it as a constructor argument;
        # classes that define it at class level supply it themselves.
        if not hasattr(cls, "joint_names"):
            kwargs["joint_names"] = self.joint_names

        return cls(
            self.skeleton,
            self.frame_indices[mask] - start,
            self.global_joints_positions[mask],
            self.global_joints_rots[mask],
            self.smooth_root_2d[mask],
            **kwargs,
        )

    def get_save_info(self) -> dict:
        """Return a dict for saving this constraint.

        Keys are ``type``, ``frame_indices``, ``local_joints_rot``, ``root_positions`` and
        ``smooth_root_2d``, plus ``joint_names`` when the class has no class-level joint_names.
        Rotations are converted to local rotations and stored as axis-angle. For a
        ``SOMASkeleton30`` they are first expanded with ``to_SOMASkeleton77``.
        """
        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
        if isinstance(self.skeleton, SOMASkeleton30):
            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
        local_joints_rot = matrix_to_axis_angle(local_joints_rot)

        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
        output = {
            "type": self.name,
            "frame_indices": self.frame_indices,
            "local_joints_rot": local_joints_rot,
            "root_positions": root_positions,
            "smooth_root_2d": self.smooth_root_2d,
        }
        if not hasattr(self.__class__, "joint_names"):
            # Save the joint_names for this base class, but not for children.
            output["joint_names"] = self.joint_names
        return output

    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "EndEffectorConstraintSet":
        """Move the stored tensors to ``device`` and/or cast the data tensors to ``dtype``, in place.

        ``frame_indices``, ``pos_indices`` and ``rot_indices`` are only moved across devices:
        they are used for tensor indexing in ``update_constraints``, which requires integer
        (or bool) index tensors, so they must keep their dtype.

        Returns:
            ``self``, to allow chaining.
        """
        # Index tensors: change device only, never the dtype.
        self.frame_indices = _tensor_to(self.frame_indices, device)
        self.pos_indices = _tensor_to(self.pos_indices, device)
        self.rot_indices = _tensor_to(self.rot_indices, device)

        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
        if device is not None and hasattr(self.skeleton, "to"):
            self.skeleton = self.skeleton.to(device)
        return self

    @classmethod
    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "EndEffectorConstraintSet":
        """Build an EndEffectorConstraintSet from a dict (e.g. loaded from JSON).

        The stored axis-angle local rotations are converted to matrices, adapted to the
        skeleton's joint count, and run through ``skeleton.fk`` together with the stored root
        positions to recover global rotations and positions. ``dico["joint_names"]`` is read
        only when ``cls`` has no class-level joint_names.
        """
        frame_indices = torch.tensor(dico["frame_indices"])
        device = skeleton.device if hasattr(skeleton, "device") else "cpu"

        local_rot = torch.tensor(dico["local_joints_rot"], device=device)
        local_rot_mats = axis_angle_to_matrix(local_rot)
        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)

        global_joints_rots, global_joints_positions, _ = skeleton.fk(
            local_rot_mats,
            torch.tensor(dico["root_positions"], device=device),
        )

        smooth_root_2d = None
        if "smooth_root_2d" in dico:
            smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)

        kwargs = {}
        if not hasattr(cls, "joint_names"):
            kwargs["joint_names"] = dico["joint_names"]

        return cls(
            skeleton,
            frame_indices=frame_indices,
            global_joints_positions=global_joints_positions,
            global_joints_rots=global_joints_rots,
            smooth_root_2d=smooth_root_2d,
            **kwargs,
        )
class LeftHandConstraintSet(EndEffectorConstraintSet):
    """End-effector constraint preset restricted to the ``LeftHand`` joint."""

    name = "left-hand"
    joint_names: list[str] = ["LeftHand"]

    def __init__(self, *args, **kwargs: dict):
        """Forward all arguments to the base constructor with ``joint_names`` fixed by the class."""
        preset_joint_names = self.joint_names
        super().__init__(*args, joint_names=preset_joint_names, **kwargs)
class RightHandConstraintSet(EndEffectorConstraintSet):
    """End-effector constraint preset restricted to the ``RightHand`` joint."""

    name = "right-hand"
    joint_names: list[str] = ["RightHand"]

    def __init__(self, *args, **kwargs: dict):
        """Forward all arguments to the base constructor with ``joint_names`` fixed by the class."""
        preset_joint_names = self.joint_names
        super().__init__(*args, joint_names=preset_joint_names, **kwargs)
class LeftFootConstraintSet(EndEffectorConstraintSet):
    """End-effector constraint preset restricted to the ``LeftFoot`` joint."""

    name = "left-foot"
    joint_names: list[str] = ["LeftFoot"]

    def __init__(self, *args, **kwargs: dict):
        """Forward all arguments to the base constructor with ``joint_names`` fixed by the class."""
        preset_joint_names = self.joint_names
        super().__init__(*args, joint_names=preset_joint_names, **kwargs)
class RightFootConstraintSet(EndEffectorConstraintSet):
    """End-effector constraint preset restricted to the ``RightFoot`` joint."""

    name = "right-foot"
    joint_names: list[str] = ["RightFoot"]

    def __init__(self, *args, **kwargs: dict):
        """Forward all arguments to the base constructor with ``joint_names`` fixed by the class."""
        preset_joint_names = self.joint_names
        super().__init__(*args, joint_names=preset_joint_names, **kwargs)
# Registry used when loading constraints: maps the serialized "type" string to the class that
# rebuilds it. Each class writes its ``name`` attribute as "type" in get_save_info, so the keys
# are taken from ``name`` directly.
TYPE_TO_CLASS = {
    constraint_cls.name: constraint_cls
    for constraint_cls in (
        Root2DConstraintSet,
        FullBodyConstraintSet,
        LeftHandConstraintSet,
        RightHandConstraintSet,
        LeftFootConstraintSet,
        RightFootConstraintSet,
        EndEffectorConstraintSet,
    )
}
def load_constraints_lst(
    path_or_data: str | list,
    skeleton: SkeletonBase,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Load a list of constraints from JSON path or list of dicts.

    Each entry's ``"type"`` key selects the constraint class through ``TYPE_TO_CLASS``, and the
    instance is built with that class's ``from_dict``.

    Args:
        path_or_data: Path to constraints.json or list of constraint dicts.
        skeleton: Skeleton instance (used for from_dict).
        device: If set, move all constraint tensors and skeleton to this device.
        dtype: If set, cast constraint tensors to this dtype.

    Returns:
        List of constraint set instances, in the same order as the input entries.
    """
    # A string is treated as a path to a JSON file; anything else is used as the data itself.
    entries = load_json(path_or_data) if isinstance(path_or_data, str) else path_or_data

    must_convert = device is not None or dtype is not None

    constraints_lst = []
    for entry in entries:
        constraint = TYPE_TO_CLASS[entry["type"]].from_dict(skeleton, entry)
        if must_convert:
            # ``to`` converts in place, so its return value is not needed.
            constraint.to(device=device, dtype=dtype)
        constraints_lst.append(constraint)
    return constraints_lst
def save_constraints_lst(path: str, constraints_lst: list) -> list | None:
    """Save a list of constraint sets to a JSON file.

    Each constraint contributes the dict returned by its ``get_save_info``, with every tensor
    converted to a plain (nested) list.

    Args:
        path: Destination passed to ``save_json``.
        constraints_lst: Constraint set instances to serialize.

    Returns:
        The list of serialized dicts that was saved, or None if ``constraints_lst`` is empty
        (nothing is written in that case).
    """
    if not constraints_lst:
        print("The constraints lst is empty. Skip saving")
        return None

    def tensor_to_list(obj):
        """Recursively convert tensors to lists for JSON serialization."""
        if isinstance(obj, Tensor):
            return obj.cpu().tolist()
        if isinstance(obj, dict):
            return {key: tensor_to_list(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [tensor_to_list(item) for item in obj]
        return obj

    # Convert all tensors to lists for JSON serialization.
    to_save = [tensor_to_list(constraint.get_save_info()) for constraint in constraints_lst]

    save_json(path, to_save)
    print(f"Saved constraints to {path}")
    return to_save