# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constraint sets for conditioning motion generation (root 2D, full body, end-effectors)."""
from typing import Optional, Union
import torch
from torch import Tensor
from kimodo.motion_rep.feature_utils import compute_heading_angle
from kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77
from kimodo.tools import ensure_batched, load_json, save_json
from .geometry import axis_angle_to_matrix, matrix_to_axis_angle
def _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor:
"""Convert loaded local rotation matrices to match the skeleton's joint count.
Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session
skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77.
"""
n_joints = local_rot_mats.shape[-3]
skeleton_joints = skeleton.nbjoints
if n_joints == skeleton_joints:
return local_rot_mats
if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30):
return skeleton.from_SOMASkeleton77(local_rot_mats)
if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77):
skel30 = SOMASkeleton30()
return skel30.to_SOMASkeleton77(local_rot_mats)
raise ValueError(
f"Constraint joint count ({n_joints}) does not match skeleton joint count "
f"({skeleton_joints}). Only SOMA 30↔77 conversion is supported."
)
[docs]
def create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor:
"""Form all (a, b) pairs from two 1D tensors; output shape (len(A)*len(B), 2)."""
pairs = torch.stack(
(
tensor_A[:, None].expand(-1, len(tensor_B)),
tensor_B.expand(len(tensor_A), -1),
),
dim=-1,
).reshape(-1, 2)
return pairs
[docs]
def compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor:
"""Compute global root heading (cos, sin) from global joint positions using skeleton."""
root_heading_angle = compute_heading_angle(global_joints_positions, skeleton)
global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)
return global_root_heading
def _tensor_to(
t: Tensor,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
"""Move tensor to device and/or dtype.
Returns same tensor if no args.
"""
if device is not None and dtype is not None:
return t.to(device=device, dtype=dtype)
if device is not None:
return t.to(device=device)
if dtype is not None:
return t.to(dtype=dtype)
return t
[docs]
class Root2DConstraintSet:
"""Constraint set fixing root (x, z) trajectory and optionally global heading on given
frames."""
name = "root2d"
[docs]
def __init__(
self,
skeleton: SkeletonBase,
frame_indices: Tensor,
smooth_root_2d: Tensor,
to_crop: bool = False,
global_root_heading: Optional[Tensor] = None,
) -> None:
self.skeleton = skeleton
# if we pass the full smooth root 3D as input
if smooth_root_2d.shape[-1] == 3:
smooth_root_2d = smooth_root_2d[..., [0, 1]]
if to_crop:
smooth_root_2d = smooth_root_2d[frame_indices]
if global_root_heading is not None:
global_root_heading = global_root_heading[frame_indices]
else:
assert len(smooth_root_2d) == len(
frame_indices
), "The number of smooth root 2d should be match the number of frames"
if global_root_heading is not None:
assert len(global_root_heading) == len(
frame_indices
), "The number of global root heading should be match the number of frames"
self.smooth_root_2d = smooth_root_2d
self.global_root_heading = global_root_heading
self.frame_indices = frame_indices
[docs]
def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
"""Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index
dicts."""
data_dict["smooth_root_2d"].append(self.smooth_root_2d)
index_dict["smooth_root_2d"].append(self.frame_indices)
if self.global_root_heading is not None:
# constraint the global heading
data_dict["global_root_heading"].append(self.global_root_heading)
index_dict["global_root_heading"].append(self.frame_indices)
[docs]
def crop_move(self, start: int, end: int) -> "Root2DConstraintSet":
"""Return a new constraint set for the cropped frame range [start, end)."""
mask = (self.frame_indices >= start) & (self.frame_indices < end)
if self.global_root_heading is not None:
masked_global_root_heading = self.global_root_heading[mask]
else:
masked_global_root_heading = None
return Root2DConstraintSet(
self.skeleton,
self.frame_indices[mask] - start,
self.smooth_root_2d[mask],
global_root_heading=masked_global_root_heading,
)
[docs]
def get_save_info(self) -> dict:
"""Return a dict suitable for JSON serialization (frame_indices, smooth_root_2d, optional
global_root_heading)."""
out = {
"type": self.name,
"frame_indices": self.frame_indices,
"smooth_root_2d": self.smooth_root_2d,
}
if self.global_root_heading is not None:
out["global_root_heading"] = self.global_root_heading
return out
[docs]
def to(
self,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> "Root2DConstraintSet":
self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
if self.global_root_heading is not None:
self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
if device is not None and hasattr(self.skeleton, "to"):
self.skeleton = self.skeleton.to(device)
return self
[docs]
@classmethod
def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "Root2DConstraintSet":
"""Build a Root2DConstraintSet from a dict (e.g. loaded from JSON)."""
device = skeleton.device if hasattr(skeleton, "device") else "cpu"
if "global_root_heading" in dico:
global_root_heading = torch.tensor(dico["global_root_heading"], device=device)
else:
global_root_heading = None
return cls(
skeleton,
frame_indices=torch.tensor(dico["frame_indices"]),
smooth_root_2d=torch.tensor(dico["smooth_root_2d"], device=device),
global_root_heading=global_root_heading,
)
[docs]
class FullBodyConstraintSet:
"""Constraint set fixing full-body global positions and rotations on given keyframes."""
name = "fullbody"
[docs]
def __init__(
self,
skeleton: SkeletonBase,
frame_indices: Tensor,
global_joints_positions: Tensor,
global_joints_rots: Tensor,
smooth_root_2d: Optional[Tensor] = None,
to_crop: bool = False,
):
self.skeleton = skeleton
self.frame_indices = frame_indices
# if we pass the full smooth root 3D as input
if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
smooth_root_2d = smooth_root_2d[..., [0, 1]]
if to_crop:
global_joints_positions = global_joints_positions[frame_indices]
global_joints_rots = global_joints_rots[frame_indices]
if smooth_root_2d is not None:
smooth_root_2d = smooth_root_2d[frame_indices]
else:
assert len(global_joints_positions) == len(
frame_indices
), "The number of global positions should be match the number of frames"
assert len(global_joints_rots) == len(
frame_indices
), "The number of global joint rotations should be match the number of frames"
if smooth_root_2d is not None:
assert len(smooth_root_2d) == len(
frame_indices
), "The number of smooth root 2d (if specified) should be match the number of frames"
if smooth_root_2d is None:
# substitute the smooth root 2d with the real root
smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
# root y: from smooth or pelvis is the same
self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
self.global_joints_positions = global_joints_positions
self.global_joints_rots = global_joints_rots
self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
self.smooth_root_2d = smooth_root_2d
[docs]
def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
"""Append global positions, smooth root 2D, root y, and global heading to data/index
dicts."""
nbjoints = self.skeleton.nbjoints
indices_lst = create_pairs(
self.frame_indices,
torch.arange(nbjoints, device=self.frame_indices.device),
)
data_dict["global_joints_positions"].append(
self.global_joints_positions.reshape(-1, 3)
) # flatten the global positions
index_dict["global_joints_positions"].append(indices_lst)
# global rotations are not used here
# as we use smooth root, also constraint the smooth root to get the same full body
# maybe keep storing the hips offset, if we smooth it ourselves
data_dict["smooth_root_2d"].append(self.smooth_root_2d)
index_dict["smooth_root_2d"].append(self.frame_indices)
# constraint the y pos of the root
data_dict["root_y_pos"].append(self.root_y_pos)
index_dict["root_y_pos"].append(self.frame_indices)
# constraint the global heading
data_dict["global_root_heading"].append(self.global_root_heading)
index_dict["global_root_heading"].append(self.frame_indices)
[docs]
def crop_move(self, start: int, end: int) -> "FullBodyConstraintSet":
"""Return a new FullBodyConstraintSet for the cropped frame range [start, end)."""
mask = (self.frame_indices >= start) & (self.frame_indices < end)
return FullBodyConstraintSet(
self.skeleton,
self.frame_indices[mask] - start,
self.global_joints_positions[mask],
self.global_joints_rots[mask],
self.smooth_root_2d[mask],
)
[docs]
def get_save_info(self) -> dict:
"""Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d."""
local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
if isinstance(self.skeleton, SOMASkeleton30):
local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
local_joints_rot = matrix_to_axis_angle(local_joints_rot)
root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
return {
"type": self.name,
"frame_indices": self.frame_indices,
"local_joints_rot": local_joints_rot,
"root_positions": root_positions,
"smooth_root_2d": self.smooth_root_2d,
}
[docs]
def to(
self,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> "FullBodyConstraintSet":
self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
if device is not None and hasattr(self.skeleton, "to"):
self.skeleton = self.skeleton.to(device)
return self
[docs]
@classmethod
def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "FullBodyConstraintSet":
"""Build a FullBodyConstraintSet from a dict (e.g. loaded from JSON)."""
frame_indices = torch.tensor(dico["frame_indices"])
device = skeleton.device if hasattr(skeleton, "device") else "cpu"
local_rot = torch.tensor(dico["local_joints_rot"], device=device)
local_rot_mats = axis_angle_to_matrix(local_rot)
local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
global_joints_rots, global_joints_positions, _ = skeleton.fk(
local_rot_mats,
torch.tensor(dico["root_positions"], device=device),
)
smooth_root_2d = None
if "smooth_root_2d" in dico:
smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
return cls(
skeleton,
frame_indices=frame_indices,
global_joints_positions=global_joints_positions,
global_joints_rots=global_joints_rots,
smooth_root_2d=smooth_root_2d,
)
[docs]
class EndEffectorConstraintSet:
"""Constraint set fixing selected end-effector positions and rotations on given frames."""
name = "end-effector"
[docs]
def __init__(
self,
skeleton: SkeletonBase,
frame_indices: Tensor,
global_joints_positions: Tensor,
global_joints_rots: Tensor,
smooth_root_2d: Optional[Tensor],
*,
joint_names: list[str],
to_crop: bool = False,
) -> None:
self.skeleton = skeleton
self.frame_indices = frame_indices
self.joint_names = joint_names
# joint_names are constant for all the frames
rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names)
# indexing works for motion_rep with smooth root only (contains pelvis index)
self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names])
self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names])
# if we pass the full smooth root 3D as input
if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:
smooth_root_2d = smooth_root_2d[..., [0, 1]]
if to_crop:
global_joints_positions = global_joints_positions[frame_indices]
global_joints_rots = global_joints_rots[frame_indices]
if smooth_root_2d is not None:
smooth_root_2d = smooth_root_2d[frame_indices]
else:
assert len(global_joints_positions) == len(
frame_indices
), "The number of global positions should be match the number of frames"
assert len(global_joints_rots) == len(
frame_indices
), "The number of global joint rotations should be match the number of frames"
if smooth_root_2d is not None:
assert len(smooth_root_2d) == len(
frame_indices
), "The number of smooth root 2d (if specified) should be match the number of frames"
if smooth_root_2d is None:
# substitute the smooth root 2d with the real root
smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]
# root y: from smooth or pelvis is the same
self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]
self.global_joints_positions = global_joints_positions
self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)
self.global_joints_rots = global_joints_rots
self.smooth_root_2d = smooth_root_2d
[docs]
def update_constraints(self, data_dict: dict, index_dict: dict) -> None:
"""Append constrained joint positions/rots, smooth root 2D, root y, and heading to
data/index dicts."""
crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device)
# constraint positions
pos_indices_real = create_pairs(
self.frame_indices,
self.pos_indices,
)
pos_indices_crop = create_pairs(
crop_frames_indexing,
self.pos_indices,
)
data_dict["global_joints_positions"].append(self.global_joints_positions[tuple(pos_indices_crop.T)])
index_dict["global_joints_positions"].append(pos_indices_real)
# constraint rotations
rot_indices_real = create_pairs(
self.frame_indices,
self.rot_indices,
)
rot_indices_crop = create_pairs(
crop_frames_indexing,
self.rot_indices,
)
data_dict["global_joints_rots"].append(self.global_joints_rots[tuple(rot_indices_crop.T)])
index_dict["global_joints_rots"].append(rot_indices_real)
# as we use smooth root, also constraint the smooth root to get the same full body
# maybe keep storing the hips offset, if we smooth it ourselves
data_dict["smooth_root_2d"].append(self.smooth_root_2d)
index_dict["smooth_root_2d"].append(self.frame_indices)
# constraint the y pos of the root
data_dict["root_y_pos"].append(self.root_y_pos)
index_dict["root_y_pos"].append(self.frame_indices)
# constraint the global heading
data_dict["global_root_heading"].append(self.global_root_heading)
index_dict["global_root_heading"].append(self.frame_indices)
[docs]
def crop_move(self, start: int, end: int) -> "EndEffectorConstraintSet":
"""Return a new EndEffectorConstraintSet for the cropped frame range [start, end)."""
mask = (self.frame_indices >= start) & (self.frame_indices < end)
cls = type(self)
kwargs = {}
if not hasattr(cls, "joint_names"):
kwargs["joint_names"] = self.joint_names
return cls(
self.skeleton,
self.frame_indices[mask] - start,
self.global_joints_positions[mask],
self.global_joints_rots[mask],
self.smooth_root_2d[mask],
**kwargs,
)
[docs]
def get_save_info(self) -> dict:
"""Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d, joint_names."""
local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)
if isinstance(self.skeleton, SOMASkeleton30):
local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)
local_joints_rot = matrix_to_axis_angle(local_joints_rot)
root_positions = self.global_joints_positions[:, self.skeleton.root_idx]
output = {
"type": self.name,
"frame_indices": self.frame_indices,
"local_joints_rot": local_joints_rot,
"root_positions": root_positions,
"smooth_root_2d": self.smooth_root_2d,
}
if not hasattr(self.__class__, "joint_names"):
# save the joint_names for this base class
# but not for children
output["joint_names"] = self.joint_names
return output
[docs]
def to(
self,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> "EndEffectorConstraintSet":
self.frame_indices = _tensor_to(self.frame_indices, device, dtype)
self.pos_indices = _tensor_to(self.pos_indices, device, dtype)
self.rot_indices = _tensor_to(self.rot_indices, device, dtype)
self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)
self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)
self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)
self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)
self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)
if device is not None and hasattr(self.skeleton, "to"):
self.skeleton = self.skeleton.to(device)
return self
[docs]
@classmethod
def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> "EndEffectorConstraintSet":
"""Build an EndEffectorConstraintSet from a dict (e.g. loaded from JSON)."""
frame_indices = torch.tensor(dico["frame_indices"])
device = skeleton.device if hasattr(skeleton, "device") else "cpu"
local_rot = torch.tensor(dico["local_joints_rot"], device=device)
local_rot_mats = axis_angle_to_matrix(local_rot)
local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)
global_joints_rots, global_joints_positions, _ = skeleton.fk(
local_rot_mats,
torch.tensor(dico["root_positions"], device=device),
)
smooth_root_2d = None
if "smooth_root_2d" in dico:
smooth_root_2d = torch.tensor(dico["smooth_root_2d"], device=device)
kwargs = {}
if not hasattr(cls, "joint_names"):
kwargs["joint_names"] = dico["joint_names"]
return cls(
skeleton,
frame_indices=frame_indices,
global_joints_positions=global_joints_positions,
global_joints_rots=global_joints_rots,
smooth_root_2d=smooth_root_2d,
**kwargs,
)
[docs]
class LeftHandConstraintSet(EndEffectorConstraintSet):
"""End-effector constraint for the left hand only."""
name = "left-hand"
joint_names: list[str] = ["LeftHand"]
[docs]
def __init__(self, *args, **kwargs: dict):
super().__init__(*args, joint_names=self.joint_names, **kwargs)
[docs]
class RightHandConstraintSet(EndEffectorConstraintSet):
"""End-effector constraint for the right hand only."""
name = "right-hand"
joint_names: list[str] = ["RightHand"]
[docs]
def __init__(self, *args, **kwargs: dict):
super().__init__(*args, joint_names=self.joint_names, **kwargs)
TYPE_TO_CLASS = {
"root2d": Root2DConstraintSet,
"fullbody": FullBodyConstraintSet,
"left-hand": LeftHandConstraintSet,
"right-hand": RightHandConstraintSet,
"left-foot": LeftFootConstraintSet,
"right-foot": RightFootConstraintSet,
"end-effector": EndEffectorConstraintSet,
}
[docs]
def load_constraints_lst(
path_or_data: str | list,
skeleton: SkeletonBase,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
):
"""Load a list of constraints from JSON path or list of dicts.
Args:
path_or_data: Path to constraints.json or list of constraint dicts.
skeleton: Skeleton instance (used for from_dict).
device: If set, move all constraint tensors and skeleton to this device.
dtype: If set, cast constraint tensors to this dtype.
"""
if isinstance(path_or_data, str):
saved = load_json(path_or_data)
else:
saved = path_or_data
constraints_lst = []
for el in saved:
cls = TYPE_TO_CLASS[el["type"]]
c = cls.from_dict(skeleton, el)
if device is not None or dtype is not None:
c.to(device=device, dtype=dtype)
constraints_lst.append(c)
return constraints_lst
[docs]
def save_constraints_lst(path: str, constraints_lst: list) -> list | None:
"""Save a list of constraint sets to a JSON file.
Returns None if list is empty.
"""
if not constraints_lst:
print("The constraints lst is empty. Skip saving")
return
to_save = []
def tensor_to_list(obj):
"""Recursively convert tensors to lists for JSON serialization."""
if isinstance(obj, Tensor):
return obj.cpu().tolist()
elif isinstance(obj, dict):
return {k: tensor_to_list(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [tensor_to_list(v) for v in obj]
else:
return obj
for constraint in constraints_lst:
constraint_info = constraint.get_save_info()
# Convert all tensors to lists for JSON serialization
constraint_info = tensor_to_list(constraint_info)
to_save.append(constraint_info)
save_json(path, to_save)
print(f"Saved constraints to {path}")
return to_save