Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations # so we can refer to class Type inside class | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Tuple, Optional | |
| import numpy as np | |
| import numpy.typing as npt | |
| from animated_drawings.model.transform import Transform | |
| from animated_drawings.model.box import Box | |
| from animated_drawings.model.quaternions import Quaternions | |
| from animated_drawings.model.vectors import Vectors | |
| from animated_drawings.model.joint import Joint | |
| from animated_drawings.model.time_manager import TimeManager | |
| from animated_drawings.utils import resolve_ad_filepath | |
class BVH_Joint(Joint):
    """
    Joint class with channel order attribute and specialized vis widget.

    channel_order holds the channel names read from this joint's BVH CHANNELS line
    (e.g. 'Xposition', 'Zrotation'); it is empty for joints that declare no channels.
    """

    def __init__(self, channel_order: Optional[List[str]] = None, widget: bool = True, **kwargs) -> None:
        """
        :param channel_order: channel names for this joint, in the order they appear in the BVH
            motion data. When omitted, a fresh empty list is created for this joint.
        :param widget: if True, create a Box and attach it as a child to visualize the joint.
        :param kwargs: forwarded unchanged to Joint.__init__ (e.g. name, offset, children).
        """
        super().__init__(**kwargs)
        # A `[]` default would be a single list shared by every joint built without an explicit
        # channel_order; allocate a new one per instance instead.
        self.channel_order: List[str] = channel_order if channel_order is not None else []
        self.widget: Optional[Transform] = None
        if widget:
            self.widget = Box()
            self.add_child(self.widget)

    def _draw(self, **kwargs) -> None:
        """Draw the visualization widget, if one was created; kwargs are passed through to it."""
        if self.widget:
            self.widget.draw(**kwargs)
class BVH(Transform, TimeManager):
    """
    Class to encapsulate BVH (Biovision Hierarchy) animation data.
    Includes a single skeletal hierarchy defined in the BVH, frame count and speed,
    and skeletal pos/rot data for each frame.
    """

    def __init__(self,
                 name: str,
                 root_joint: BVH_Joint,
                 frame_max_num: int,
                 frame_time: float,
                 pos_data: npt.NDArray[np.float32],
                 rot_data: npt.NDArray[np.float32]
                 ) -> None:
        """
        Don't recommend calling this method directly. Instead, use BVH.from_file().

        :param name: name of this BVH (from_file passes the BVH file's name).
        :param root_joint: root of the skeleton; attached as a child of this transform.
        :param frame_max_num: number of frames held in pos_data / rot_data.
        :param frame_time: duration of one frame, in the same units as get_time().
        :param pos_data: root position per frame, shape (frames, 3).
        :param rot_data: per-joint quaternions per frame, shape (frames, joint_count, 4).
        """
        super().__init__()
        self.name: str = name
        self.frame_max_num: int = frame_max_num
        self.frame_time: float = frame_time
        self.pos_data: npt.NDArray[np.float32] = pos_data
        self.rot_data: npt.NDArray[np.float32] = rot_data
        self.root_joint = root_joint
        self.add_child(self.root_joint)
        self.joint_num = self.root_joint.joint_count()
        self.cur_frame = 0  # initialize skeleton pose to first frame
        self.apply_frame(self.cur_frame)

    @staticmethod
    def _critical_error(msg: str) -> AssertionError:
        """
        Log msg at CRITICAL level and return an AssertionError carrying it, for the caller to raise.

        Raising explicitly (rather than `assert False, msg`) keeps these checks active under
        `python -O`, while callers still receive the same AssertionError type.
        """
        logging.critical(msg)
        return AssertionError(msg)

    @staticmethod
    def _malformed_msg(lines: List[str]) -> str:
        """Build a 'malformed BVH' message quoting only the next unconsumed line, not the whole remaining file."""
        next_line = lines[0] if lines else '<end of file>'
        return f'Malformed BVH in line preceding {next_line}'

    def get_joint_names(self) -> List[str]:
        """ Get names of joints in skeleton in the order in which BVH rotation data is stored. """
        return self.root_joint.get_chain_joint_names()

    def update(self) -> None:
        """Based upon internal time, determine which frame should be displayed and apply it (looping past the end)."""
        cur_time: float = self.get_time()
        cur_frame = round(cur_time / self.frame_time) % self.frame_max_num
        self.apply_frame(cur_frame)

    def apply_frame(self, frame_num: int) -> None:
        """ Apply root position and joint rotation data for specified frame_num """
        self.root_joint.set_position(self.pos_data[frame_num])
        # A 0-d array is used as a counter that the recursive helper increments in place.
        self._apply_frame_rotations(self.root_joint, frame_num, ptr=np.array(0))

    def _apply_frame_rotations(self, joint: BVH_Joint, frame_num: int, ptr: npt.NDArray[np.int32]) -> None:
        """
        Depth-first, set each BVH_Joint's rotation from rot_data[frame_num, ptr], advancing ptr per joint.

        The traversal order (self, then BVH_Joint children in get_children() order) matches the
        order in which _pose_ea_to_q filled rot_data's joint axis.
        """
        q = Quaternions(self.rot_data[frame_num, ptr])
        joint.set_rotation(q)
        ptr += 1  # in-place on the shared 0-d array, so siblings/parents see the advance
        for c in joint.get_children():
            if not isinstance(c, BVH_Joint):
                continue  # skip non-joint children such as visualization widgets
            self._apply_frame_rotations(c, frame_num, ptr)

    def get_skeleton_fwd(self, forward_perp_vector_joint_names: List[Tuple[str, str]], update: bool = True) -> Vectors:
        """
        Get current forward vector of skeleton in world coords. If update=True, ensure skeleton transforms are current.
        Input forward_perp_vector_joint_names, a list of pairs of joint names (e.g. [[leftshould, rightshoulder], [lefthip, righthip]])
        Finds average of vectors between joint pairs, then returns vector perpendicular to their average.

        :raises AssertionError: if a named joint cannot be found.
        """
        if update:
            self.root_joint.update_transforms(update_ancestors=True)

        vectors_cw_perpendicular_to_fwd: List[Vectors] = []
        for (start_joint_name, end_joint_name) in forward_perp_vector_joint_names:
            start_joint = self.root_joint.get_transform_by_name(start_joint_name)
            if not start_joint:
                raise self._critical_error(f'Could not find BVH joint with name: {start_joint_name}')

            end_joint = self.root_joint.get_transform_by_name(end_joint_name)
            if not end_joint:
                raise self._critical_error(f'Could not find BVH joint with name: {end_joint_name}')

            bone_vector: Vectors = Vectors(end_joint.get_world_position()) - Vectors(start_joint.get_world_position())
            bone_vector.norm()  # return value unused; presumably normalizes in place — confirm in Vectors
            vectors_cw_perpendicular_to_fwd.append(bone_vector)

        return Vectors(vectors_cw_perpendicular_to_fwd).average().perpendicular()

    @classmethod
    def from_file(cls, bvh_fn: str, start_frame_idx: int = 0, end_frame_idx: Optional[int] = None) -> BVH:
        """
        Given a path to a .bvh, constructs and returns BVH object.

        :param bvh_fn: path to the BVH file; located via resolve_ad_filepath.
        :param start_frame_idx: index of the first frame to keep (Python slice semantics).
        :param end_frame_idx: index one past the last frame to keep. None (or 0) means "through the
            last frame"; values larger than the file's frame count are clamped with a warning.
        :return: BVH holding frames [start_frame_idx:end_frame_idx] of the file.
        :raises AssertionError: (after logging at CRITICAL) if the file is malformed or the frame
            range selects no frames.
        """
        # search for the BVH file specified
        bvh_p: Path = resolve_ad_filepath(bvh_fn, 'bvh file')
        logging.info(f'Using BVH file located at {bvh_p.resolve()}')

        with open(str(bvh_p), 'r') as f:
            lines = f.read().splitlines()

        if not lines or lines.pop(0).strip() != 'HIERARCHY':
            raise cls._critical_error(cls._malformed_msg(lines))

        # Parse the skeleton (consumes the hierarchy lines from the front of `lines`)
        root_joint: BVH_Joint = cls._parse_skeleton(lines)

        if not lines or lines.pop(0).strip() != 'MOTION':
            raise cls._critical_error(cls._malformed_msg(lines))

        # Parse motion metadata ("Frames: N" and "Frame Time: T")
        if len(lines) < 2:
            raise cls._critical_error(cls._malformed_msg(lines))
        frame_max_num = int(lines.pop(0).split(':')[-1])
        frame_time = float(lines.pop(0).split(':')[-1])

        # Parse motion data: one frame per non-blank line. split() with no separator tolerates
        # tabs, repeated spaces and trailing whitespace, which split(' ') turned into float('') errors.
        frames = [list(map(float, line.split())) for line in lines if line.strip()]
        if len(frames) != frame_max_num:
            raise cls._critical_error(f'framenum specified ({frame_max_num}) and found ({len(frames)}) do not match')

        # Split logically distinct root position data from joint euler angle rotation data
        pos_data: npt.NDArray[np.float32]
        rot_data: npt.NDArray[np.float32]
        pos_data, rot_data = cls._process_frame_data(root_joint, frames)

        # Set end_frame if not passed in (note: a falsy 0 is also treated as "not passed in")
        if not end_frame_idx:
            end_frame_idx = frame_max_num

        # Ensure end_frame_idx <= frame_max_num
        if frame_max_num < end_frame_idx:
            msg = f'config specified end_frame_idx > bvh frame_max_num ({end_frame_idx} > {frame_max_num}). Replacing with frame_max_num.'
            logging.warning(msg)
            end_frame_idx = frame_max_num

        # slice position and rotation data using start and end frame indices
        pos_data = pos_data[start_frame_idx:end_frame_idx, :]
        rot_data = rot_data[start_frame_idx:end_frame_idx, :]

        # Count the frames actually kept instead of computing end - start: the two disagree for
        # negative or out-of-range indices, and update()/apply_frame index by frame_max_num.
        frame_max_num = pos_data.shape[0]
        if frame_max_num == 0:
            raise cls._critical_error(
                f'Frame range [{start_frame_idx}:{end_frame_idx}] selects no frames from {bvh_p.name}')

        return cls(bvh_p.name, root_joint, frame_max_num, frame_time, pos_data, rot_data)

    @classmethod
    def _parse_skeleton(cls, lines: List[str]) -> BVH_Joint:
        """
        Called recursively to parse and construct skeleton from BVH.

        :param lines: partially-processed contents of BVH file. Consumed from the front, in-place.
        :return: the parsed joint with its child joints attached.
        :raises AssertionError: if the hierarchy section is malformed.
        """
        # Get the joint name
        header = lines[0].strip() if lines else ''
        if header.startswith(('ROOT', 'JOINT')):
            # split(maxsplit=1) tolerates tabs / repeated spaces and keeps spaces inside the name
            parts = header.split(maxsplit=1)
            if len(parts) != 2:
                raise cls._critical_error(f'Malformed BVH. Line: {lines[0]}')
            joint_name = parts[1]
        elif header.startswith('End Site'):
            joint_name = header
        else:
            raise cls._critical_error(f'Malformed BVH. Line: {lines[0] if lines else "<end of file>"}')
        lines.pop(0)

        if not lines or lines.pop(0).strip() != '{':
            raise cls._critical_error(cls._malformed_msg(lines))

        # Get offset
        if not lines or not lines[0].strip().startswith('OFFSET'):
            raise cls._critical_error(f'Malformed BVH. Expected OFFSET. Line: {lines[0] if lines else "<end of file>"}')
        _, *xyz = lines.pop(0).split()
        offset = Vectors(list(map(float, xyz)))

        # Get channels (joints such as End Sites may declare none)
        if lines and lines[0].strip().startswith('CHANNELS'):
            _, channel_num, *channel_order = lines.pop(0).split()
        else:
            channel_num, channel_order = 0, []

        if int(channel_num) != len(channel_order):
            raise cls._critical_error(cls._malformed_msg(lines))

        # Recurse for children until this joint's closing brace
        children: List[BVH_Joint] = []
        while lines and lines[0].strip() != '}':
            children.append(cls._parse_skeleton(lines))
        if not lines:
            raise cls._critical_error(f'Malformed BVH. Reached end of file before closing brace of {joint_name}')
        lines.pop(0)  # }

        return BVH_Joint(name=joint_name, offset=offset, channel_order=channel_order, children=children)

    @classmethod
    def _process_frame_data(cls, skeleton: BVH_Joint, frames: List[List[float]]) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """
        Given skeleton and frame data, return root position data and joint quaternion data, separately.

        :param skeleton: root joint of the parsed hierarchy.
        :param frames: one list of channel values per frame, in the file's column order.
        :return: (pos_data, rot_data); pos_data has shape (frames, 3) and rot_data has shape
            (frames, skeleton.joint_count(), 4).
        """
        # Gather channel names depth-first (self, then BVH_Joint children), the same order
        # _parse_skeleton read them in. An explicit accumulator replaces the former `channels=[]` default.
        channels: List[str] = []

        def _collect_channels(joint: BVH_Joint) -> None:
            channels.extend(joint.channel_order)
            for child in joint.get_children():
                if isinstance(child, BVH_Joint):
                    _collect_channels(child)

        _collect_channels(skeleton)

        # create a mask so we retain only joint rotations and root position
        mask = np.array(['rotation' in channel for channel in channels], dtype=bool)
        mask[:3] = True  # hack to make sure we keep root position (assumes the first 3 channels are root position)

        frame_data = np.array(frames, dtype=np.float32)[:, mask]

        # split root pose data and joint euler angle data
        pos_data, ea_rots = np.split(frame_data, [3], axis=1)

        # quaternion rot data will go here
        rot_data = np.empty([len(frame_data), skeleton.joint_count(), 4], dtype=np.float32)
        cls._pose_ea_to_q(skeleton, ea_rots, rot_data)

        return pos_data, rot_data

    @classmethod
    def _pose_ea_to_q(cls, joint: BVH_Joint, ea_rots: npt.NDArray[np.float32], q_rots: npt.NDArray[np.float32], p1: int = 0, p2: int = 0) -> Tuple[int, int]:
        """
        Given joint and array of euler angle rotation data, converts to quaternions and stores in q_rots.
        Only called by _process_frame_data(). Modifies q_rots inplace.

        :param p1: pointer to find where in ea_rots to read euler angles from
        :param p2: pointer to determine where in q_rots to input quaternion
        :return: (p1, p2) advanced past this joint and all of its BVH_Joint descendants.
        """
        axis_chars = "".join([c[0].lower() for c in joint.channel_order if c.endswith('rotation')])  # e.g. 'xyz'
        q_rots[:, p2] = Quaternions.from_euler_angles(axis_chars, ea_rots[:, p1:p1 + len(axis_chars)]).qs
        p1 += len(axis_chars)
        p2 += 1
        for child in joint.get_children():
            if isinstance(child, BVH_Joint):
                p1, p2 = cls._pose_ea_to_q(child, ea_rots, q_rots, p1, p2)
        return p1, p2