Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import annotations # so we can refer to class Type inside class | |
| import numpy as np | |
| import numpy.typing as npt | |
| from animated_drawings.model.vectors import Vectors | |
| from animated_drawings.model.quaternions import Quaternions | |
| import logging | |
| from typing import Union, Optional, List, Tuple | |
class Transform:
    """Base class from which all other scene objects descend.

    A Transform is a node in a tree (one optional parent, a list of children).
    It stores separate 4x4 translate, rotate and scale matrices, from which a
    local matrix (translate @ rotate @ scale) and a world matrix
    (parent world @ local) are derived lazily. `dirty_bit` marks the cached
    local/world matrices as stale.
    """

    def __init__(self,
                 parent: Optional[Transform] = None,
                 name: Optional[str] = None,
                 children: Optional[List[Transform]] = None,
                 offset: Union[npt.NDArray[np.float32], Vectors, None] = None,
                 **kwargs
                 ) -> None:
        """
        Create a transform.

        parent: stored as this transform's parent. Note that this does not
            append the new transform to the parent's child list.
        name: optional identifier used by get_transform_by_name.
        children: transforms to attach via add_child. None means no children.
        offset: optional initial translation, applied via offset().
        kwargs: forwarded to the next __init__ in the MRO.
        """
        super().__init__(**kwargs)

        self._parent: Optional[Transform] = parent

        # A None default avoids sharing one list object between every call.
        self._children: List[Transform] = []
        for child in (children if children is not None else []):
            self.add_child(child)

        self.name: Optional[str] = name

        self._translate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._rotate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._scale_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)

        if offset is not None:
            self.offset(offset)

        self._local_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._world_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self.dirty_bit: bool = True  # are world/local transforms stale?

    def update_transforms(self, parent_dirty_bit: bool = False, recurse_on_children: bool = True, update_ancestors: bool = False) -> None:
        """
        Updates transforms if stale.
        If own dirty bit is set, recompute local matrix.
        If own or parent's dirty bit is set, recompute world matrix.
        Children are visited (and told whether this node was stale) unless
        recurse_on_children is false; visiting them unconditionally ensures a
        stale descendant beneath clean ancestors is still refreshed.
        If update_ancestors is true, first find the root ancestor and call
        update_transforms upon it.
        Set dirty bit back to false.
        """
        if update_ancestors:
            # Walk up to the root of the hierarchy and refresh from there.
            ancestor, ancestor_parent = self, self.get_parent()
            while ancestor_parent is not None:
                ancestor, ancestor_parent = ancestor_parent, ancestor_parent.get_parent()
            ancestor.update_transforms()

        stale: bool = bool(self.dirty_bit or parent_dirty_bit)

        if self.dirty_bit:
            self.compute_local_transform()
        if stale:
            self.compute_world_transform()

        if recurse_on_children:
            for c in self.get_children():
                c.update_transforms(stale)

        self.dirty_bit = False

    def compute_local_transform(self) -> None:
        """Recompute the local matrix as translate @ rotate @ scale."""
        self._local_transform = self._translate_m @ self._rotate_m @ self._scale_m

    def compute_world_transform(self) -> None:
        """Recompute the world matrix from the cached local matrix and, if present, the parent's cached world matrix."""
        self._world_transform = self._local_transform
        # Identity check, not truthiness: a subclass defining __len__/__bool__
        # must not make an existing parent look absent.
        if self._parent is not None:
            self._world_transform = self._parent._world_transform @ self._world_transform

    def get_world_transform(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Get a copy of the transform's world matrix.
        If update_ancestors is true, refresh the hierarchy from the root first
        so the returned world transform is current.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform)

    def set_scale(self, scale: float) -> None:
        """ Set a uniform scale factor on all three axes and mark the transform stale """
        self._scale_m[:-1, :-1] = scale * np.identity(3, dtype=np.float32)
        self.dirty_bit = True

    def set_position(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """
        Set the absolute values of the translational elements of transform.

        pos must be a Vectors or an array of shape (1, 3) or (3,).
        Raises AssertionError (after logging at critical level) otherwise.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs

        if pos.shape == (1, 3):
            pos = np.squeeze(pos)
        elif pos.shape != (3,):
            msg = f'bad vector dim passed to set_position. Found: {pos.shape}'
            logging.critical(msg)
            # Explicit raise: a bare `assert` would vanish under `python -O`.
            raise AssertionError(msg)

        self._translate_m[:-1, -1] = pos
        self.dirty_bit = True

    def get_local_position(self) -> npt.NDArray[np.float32]:
        """ Ensure local transform is up-to-date and return a copy of the local xyz coordinates """
        if self.dirty_bit:
            # Only the local matrix is refreshed here; dirty_bit stays set so
            # the world matrix is still recomputed on the next update.
            self.compute_local_transform()
        return np.copy(self._local_transform[:-1, -1])

    def get_world_position(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of the world xyz coordinates.
        If update_ancestors is true, update ancestor transforms to ensure an
        up-to-date world_transform before returning.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform[:-1, -1])

    def offset(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """
        Translational offset by the specified amount, relative to the current translation.

        Raises AssertionError if pos is neither a Vectors nor a numpy array.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs[0]
        if not isinstance(pos, np.ndarray):
            # Explicit raise: a bare `assert` would vanish under `python -O`.
            raise AssertionError(f'offset expects Vectors or numpy array. Found: {type(pos)}')
        self.set_position(self._translate_m[:-1, -1] + pos)

    def look_at(self, fwd_: Union[npt.NDArray[np.float32], Vectors, None]) -> None:
        """
        Given a forward vector, rotate the transform to face that position.

        If fwd_ is None, the transform's current world position is used as the
        forward vector. The vector must have shape (1, 3); otherwise the error
        is logged at critical level and AssertionError is raised.
        """
        if fwd_ is None:
            fwd_ = Vectors(self.get_world_position())
        elif isinstance(fwd_, np.ndarray):
            fwd_ = Vectors(fwd_)
        fwd: Vectors = fwd_.copy()  # norming will change the vector

        if fwd.vs.shape != (1, 3):
            msg = f'look_at fwd_ vector must have shape [1,3]. Found: {fwd.vs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)

        tmp: Vectors = Vectors([0.0, 1.0, 0.0])

        # if fwd and tmp are same (or opposite) vector, modify tmp to avoid collapse
        if np.isclose(fwd.vs, tmp.vs).all() or np.isclose(fwd.vs, -tmp.vs).all():
            tmp.vs[0] += 0.001

        right: Vectors = tmp.cross(fwd)
        up: Vectors = fwd.cross(right)

        fwd.norm()
        right.norm()
        up.norm()

        # Basis vectors become the first three columns of the rotation matrix.
        rotate_m = np.identity(4, dtype=np.float32)
        rotate_m[:-1, 0] = np.squeeze(right.vs)
        rotate_m[:-1, 1] = np.squeeze(up.vs)
        rotate_m[:-1, 2] = np.squeeze(fwd.vs)

        self._rotate_m = rotate_m
        self.dirty_bit = True

    def get_right_up_fwd_vectors(self) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """Return (right, up, fwd): the xyz parts of columns 0, 1 and 2 of the inverse of the current world matrix."""
        inverted: npt.NDArray[np.float32] = np.linalg.inv(self.get_world_transform())
        right: npt.NDArray[np.float32] = inverted[:-1, 0]
        up: npt.NDArray[np.float32] = inverted[:-1, 1]
        fwd: npt.NDArray[np.float32] = inverted[:-1, 2]

        return right, up, fwd

    def set_rotation(self, q: Quaternions) -> None:
        """
        Replace the rotation with the one described by q.

        q.qs must have shape (1, 4); otherwise the error is logged at critical
        level and AssertionError is raised.
        """
        if q.qs.shape != (1, 4):
            msg = f'set_rotate q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = q.to_rotation_matrix()
        self.dirty_bit = True

    def rotation_offset(self, q: Quaternions) -> None:
        """
        Compose q with the current rotation (q * current) and store the result.

        q.qs must have shape (1, 4); otherwise the error is logged at critical
        level and AssertionError is raised.
        """
        if q.qs.shape != (1, 4):
            msg = f'rotation_offset q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            raise AssertionError(msg)
        self._rotate_m = (q * Quaternions.from_rotation_matrix(self._rotate_m)).to_rotation_matrix()
        self.dirty_bit = True

    def add_child(self, child: Transform) -> None:
        """Append child to this transform's children and make this transform its parent."""
        self._children.append(child)
        child.set_parent(self)

    def get_children(self) -> List[Transform]:
        """Return the list of children (the internal list itself, not a copy)."""
        return self._children

    def set_parent(self, parent: Transform) -> None:
        """Set the parent reference and mark this transform stale. Does not modify any child list."""
        self._parent = parent
        self.dirty_bit = True

    def get_parent(self) -> Optional[Transform]:
        """Return the parent transform, or None if there is none."""
        return self._parent

    def get_transform_by_name(self, name: str) -> Optional[Transform]:
        """ Search self and children (depth-first) for transform with matching name. Return it if found, None otherwise. """

        # are we match?
        if self.name == name:
            return self

        # recurse to check if a child is match
        for c in self.get_children():
            transform_or_none = c.get_transform_by_name(name)
            # Identity check, not truthiness: a found transform whose class
            # defines __len__/__bool__ must still count as a match.
            if transform_or_none is not None:
                return transform_or_none

        # no match
        return None

    def draw(self, recurse: bool = True, **kwargs) -> None:
        """ Draw this transform and, if recurse is true, its children (which in turn recurse by default) """
        self._draw(**kwargs)

        if recurse:
            for child in self.get_children():
                child.draw(**kwargs)

    def _draw(self, **kwargs) -> None:
        """Transforms default to not being drawn. Subclasses must implement how they appear"""