"""
Code for aligning 3D pose to a fly template.
The best practice for getting good alignment is to have an accurate 3D pose and
a template whose key points are matching the tracked key points closely.
This class expects a 3D file in the following format
>>> pose_data_dict = {
"<side (R,L)><segment (F, M, H)>_leg": np.ndarray[N_frames,5,3],
"<side (R,L)>_head": np.ndarray[N_frames,2,3],
"Neck": np.ndarray[N_frames,1,3],
}
Usage might differ based on the needs. For now, there are three cases that you
can use the class with:
Case 1: we have 3D pose obtained, and we would like to align it but first
we need to convert the pose data into a dictionary format
NOTE: if the 3D pose is not in the format described above, then you need to
* Convert your 3D pose file manually to the required format
* Or, if you obtain the 3D pose from anipose, simply set `convert_func`
to `convert_from_anipose_to_dict` .
>>> data_path = Path("../data/anipose_220525_aJO_Fly001_001/pose-3d")
>>> align = AlignPose.from_file_path(
>>> main_dir=data_path,
>>> file_name="pose3d.h5",
>>> legs_list=["RF","LF"],
>>> convert_func=convert_from_anipose_to_dict,
>>> pts2align=PTS2ALIGN,
>>> include_claw=False,
>>> body_template=NMF_TEMPLATE,
>>> )
>>> aligned_pos = align.align_pose(export_path=data_path)
Case 2: we have a pose data in the required data structure, we just want to load and align it
>>> data_path = Path("../data/anipose_220525_aJO_Fly001_001/pose-3d")
>>> align = AlignPose.from_file_path(
>>> main_dir=data_path,
>>> file_name="converted_pose_dict.pkl",
>>> legs_list=["RF","LF"],
>>> convert_func=None,
>>> pts2align=PTS2ALIGN,
>>> include_claw=False,
>>> body_template=NMF_TEMPLATE,
>>> )
>>> aligned_pos = align.align_pose(export_path=data_path)
Case 3: we have a pose data in the required format loaded already, we want to feed it
to the class and align the pose. This assumes that the pose data is already aligned
in the right format. If not, use the static method `convert_from_anipose`.
>>> data_path = Path("../data/anipose_220525_aJO_Fly001_001/pose-3d")
>>> f_path = data_path / "converted_pose_dict.pkl"
>>> with open(f_path, "rb") as f:
>>> pose_data = pickle.load(f)
>>> align = AlignPose(
>>> pose_data_dict=pose_data,
>>> legs_list=["RF","LF"],
>>> include_claw=False,
>>> body_template=NMF_TEMPLATE,
>>> )
>>> aligned_pos = align.align_pose(export_path=data_path)
"""
from pathlib import Path
from typing import Dict, List, Union, Optional, Literal, Callable
import pickle
import logging
import numpy as np
from seqikpy.data import PTS2ALIGN, NMF_TEMPLATE
from seqikpy.utils import save_file, calculate_body_size, dict_to_nparray_pose
logging.basicConfig(
format=" %(asctime)s - %(levelname)s- %(message)s",
handlers=[logging.StreamHandler()],
)
def _get_mean_quantile(vector, quantile_diff=0.05):
"""Returns the mean of upper and lower quantiles."""
return 0.5 * (
np.quantile(vector, q=0.5 - quantile_diff)
+ np.quantile(vector, q=0.5 + quantile_diff)
)
def _leg_length_model(nmf_size: dict, leg_name: str, claw_is_ee: bool):
"""Sums up the segments of the model leg size."""
if claw_is_ee:
return nmf_size[leg_name]
return nmf_size[leg_name] - nmf_size[f"{leg_name}_Tarsus"]
def _get_distance_btw_vecs(vector1, vector2):
"""Calculates the distance between two vectors."""
return np.linalg.norm(vector1 - vector2, axis=1)
[docs]
def convert_from_anipose_to_dict(
pose_3d: Dict[str, np.ndarray], pts2align: Dict[str, List[str]]
) -> Dict[str, np.ndarray]:
"""Loads anipose 3D pose data into a dictionary.
See data.py for a mapping from keypoint name to segment name.
Parameters
----------
pose_3d : Dict[str, np.ndarray]
3D pose data from anipose.
It should have the following format
>>> pose_3d = {
"{keypoint_name}_x" : np.ndarray[N_frames,],
"{keypoint_name}_y" : np.ndarray[N_frames,],
"{keypoint_name}_z" : np.ndarray[N_frames,],
}
pts2align : Dict[str, List[str]]
Segment names and corresponding key point names to be aligned,
check data.py for an example, by default None
Returns
-------
Dict[str, np.ndarray]
Pose data dictionary of the following format
>>> pose_data_dict = {
"RF_leg": np.ndarray[N_frames,N_key_points,3],
"LF_leg": np.ndarray[N_frames,N_key_points,3],
"R_head": np.ndarray[N_frames,N_key_points,3],
"L_head": np.ndarray[N_frames,N_key_points,3],
"Neck": np.ndarray[N_frames,N_key_points,3],
...
}
"""
points_3d_dict = {}
for segment in pts2align:
segment_kps = pts2align[segment]
temp_array = np.empty(
(pose_3d[f"{segment_kps[0]}_x"].shape[0], len(segment_kps), 3)
)
for i, kp_name in enumerate(segment_kps):
temp_array[:, i, 0] = pose_3d[f"{kp_name}_x"]
temp_array[:, i, 1] = pose_3d[f"{kp_name}_y"]
temp_array[:, i, 2] = pose_3d[f"{kp_name}_z"]
points_3d_dict[segment] = temp_array.copy()
return points_3d_dict
[docs]
def convert_from_df3d_to_dict(
pose_3d: np.ndarray, pts2align: Dict[str, np.ndarray]
) -> Dict[str, np.ndarray]:
"""Loads DeepFly3D data into a dictionary.
See the original DeepFly3D repository for indices of
key points for each segment.
Parameters
----------
pose_3d : np.ndarray
Array (N, N_key_points, 3) containing 3D pose data.
pts2align : Dict[str, np.ndarray]
Dictionary mapping segment names to key point indices.
Should be in the following format
>>> pts2align = {
"RF_leg": np.arange(0,5),
"RM_leg": np.arange(5,10),
"RH_leg": np.arange(10,15),
"LF_leg": np.arange(19,24),
"LM_leg": np.arange(24,29),
"LH_leg": np.arange(29,34),
}
Returns
-------
Dict[str, np.ndarray]
Pose data dictionary as described above.
"""
points_3d_dict = {}
for segment, segment_idx in pts2align.items():
points_3d_dict[segment] = pose_3d[:, segment_idx, :].copy()
return points_3d_dict
[docs]
def convert_from_df3dpp_to_dict(
pose_3d: Dict[str, Dict[str, np.ndarray]],
pts2align: Optional[List[str]] = None,
) -> Dict[str, np.ndarray]:
"""Load DeepFly3DPostProcessing data into a dictionary.
Parameters
----------
pose_3d : Dict[str, Dict[str, np.ndarray]]
3D pose data from DeepFly3DPostProcessing.
See the original repository for details.
pts2align : Optional[List[str]], optional
List of legs to take into account, by default None
Example format
>>> pts2align = ["RF_leg", "LF_leg"]
Returns
-------
Dict[str, np.ndarray]
Pose data dictionary as described above.
"""
points_3d_dict = {}
if pts2align is None:
pts2align = list(pose_3d.keys())
for segment in pts2align:
points_3d_dict[segment] = dict_to_nparray_pose(
pose_3d[segment], claw_is_end_effector=True
)
return points_3d_dict
[docs]
class AlignPose:
"""Aligns the 3D poses. For the class usage examples, please refer to example_alignment.py
Parameters
----------
pose_data_dict : Dict[str, np.ndarray]
3D pose put in a dictionary that has the following structure defined
by PTS2ALIGN (see data.py for more details)
Example format
>>> pose_data_dict = {
"RF_leg": np.ndarray[N_frames,N_key_points,3],
"LF_leg": np.ndarray[N_frames,N_key_points,3],
"R_head": np.ndarray[N_frames,N_key_points,3],
"L_head": np.ndarray[N_frames,N_key_points,3],
"Neck": np.ndarray[N_frames,N_key_points,3],
}
legs_list : List[str]
A list containing the leg names to operate on.
Should follow the convention <R or L><F or M or H>
e.g., ["RF", "LF", "RM", "LM", "RH", "LH"]
include_claw : bool, optional
If True, claw is included in the scaling process, by default False
body_template : Dict[str, np.ndarray], optional
A dictionary containing the positions of fly model body segments.
Check ./data.py for the default dictionary, by default None
body_size : Dict[str, float], optional
A dictionary containing the limb size of the fly.
If the user wants to scale the animal data to match the
biomechanical model, then `body_size` should be the same as
the model body size. Otherwise, the user should calculate the
animal's body size.
Check ./data.py for an example.
log_level : Literal["DEBUG", "INFO", "WARNING", "ERROR"], optional
Logging level as a string, by default "INFO"
"""
def __init__(
self,
pose_data_dict: Dict[str, np.ndarray],
legs_list: List[str],
include_claw: Optional[bool] = False,
body_template: Optional[Dict[str, np.ndarray]] = None,
body_size: Optional[Dict[str, float]] = None,
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO",
) -> None:
self.pose_data_dict = pose_data_dict
self.include_claw = include_claw
self.body_template = NMF_TEMPLATE if body_template is None else body_template
# Calculate the size of the limbs from the template
if body_size is None:
self.body_size = calculate_body_size(self.body_template, legs_list)
else:
self.body_size = body_size
# Get the logger of the module
self.logger = logging.getLogger(self.__class__.__name__)
numeric_level = getattr(logging, log_level.upper(), None)
self.logger.setLevel(numeric_level)
[docs]
@classmethod
def from_file_path(
cls,
main_dir: Union[str, Path],
file_name: Optional[str] = "pose3d.*",
convert_func: Optional[Callable] = None,
pts2align: Optional[Dict[str, List[str]]] = None,
**kwargs,
):
"""
Class method to load pose3d data and convert it into a proper
structure.
Parameters
----------
main_dir : Union[str, Path]
Path where the Anipose triangulation results are saved.
By default, the result file is caled pose3d.h5
However, if the name is different, <pose_result_path> should
be modified accordingly.
Example: "../Fly001/001_Beh/behData/pose-3d"
file_name : str, optional
File name, by default "pose3d.*"
convert_func : Callable, optional
Function to convert the loaded pose into the required format,
set by the user if None, no conversion is performed,
by default None
pts2align : Dict[str, List[str]], optional
Body part names and corresponding key points names to be aligned,
check data.py for an example, by default None
Returns
-------
AlignPose
Instance of the AlignPose class.
Raises
------
FileNotFoundError
If file with a name that contains `file_name` does not exist in `main_dir`.
"""
paths = list(Path(main_dir).rglob(file_name))
if len(paths) > 0:
with open(paths[-1].as_posix(), "rb") as f:
pose_3d = pickle.load(f)
else:
raise FileNotFoundError(f"{file_name} does not exits in {main_dir}")
if convert_func is not None:
pts2align = PTS2ALIGN if pts2align is None else pts2align
converted_dict = convert_func(pose_3d, pts2align)
return cls(converted_dict, **kwargs)
return cls(pose_3d, **kwargs)
[docs]
def align_pose(
self, export_path: Optional[Union[str, Path]] = None
) -> Dict[str, np.ndarray]:
"""Aligns the leg and head key point positions.
Parameters
----------
export_path : Union[str, Path], optional
The path where the aligned pose data will be saved, if specified.
Returns
-------
Dict[str, np.ndarray]
A dictionary containing the aligned pose data.
"""
aligned_pose = {}
for segment, segment_array in self.pose_data_dict.items():
if "leg" in segment:
aligned_pose[segment] = self.align_leg(
leg_array=segment_array, leg_name=segment[:2]
)
elif "head" in segment:
aligned_pose[segment] = self.align_head(
head_array=segment_array, side=segment[0]
)
else:
self.logger.debug("%s is not aligned", segment)
continue
# Take the neck as in the template as the other points are already aligned
if "Neck" in self.body_template:
aligned_pose["Neck"] = self.body_template["Neck"].reshape((-1, 1, 3))
if export_path is not None:
export_full_path = export_path / "pose3d_aligned.pkl"
save_file(out_fname=export_full_path, data=aligned_pose)
self.logger.info("Aligned pose is saved at %s", export_path)
return aligned_pose
@property
def thorax_mid_pts(self) -> np.ndarray:
"""Gets the middle point of right and left wing hinges."""
assert (
"Thorax" in self.pose_data_dict
), "To align the head, you need to have a `Thorax` key point"
thorax_pts = self.pose_data_dict["Thorax"]
return 0.5 * (thorax_pts[:, 0, :] + thorax_pts[:, -1, :])
[docs]
@staticmethod
def get_fixed_pos(points_3d: np.ndarray) -> np.ndarray:
"""Gets the fixed pose of a steady key point determined by the quantiles."""
fixed_pos = [
_get_mean_quantile(points_3d[:, 0]),
_get_mean_quantile(points_3d[:, 1]),
_get_mean_quantile(points_3d[:, 2]),
]
return np.array(fixed_pos)
[docs]
def get_mean_length(
self, segment_array: np.ndarray, segment_is_leg: bool
) -> Dict[str, float]:
"""Computes the mean length of a body segment."""
lengths = np.linalg.norm(np.diff(segment_array, axis=1), axis=2)
if segment_is_leg:
segments = ["coxa", "femur", "tibia", "tarsus"]
else:
segments = ["antenna"]
length_mean = {}
for i, s in enumerate(segments):
length_mean[s] = _get_mean_quantile(lengths[:, i])
return length_mean
[docs]
def find_scale_leg(self, leg_name: str, mean_length: Dict) -> float:
"""Computes the ratio between the model size and the real fly size."""
nmf_size = _leg_length_model(self.body_size, leg_name, self.include_claw)
fly_leg_size = mean_length["coxa"] + mean_length["femur"] + mean_length["tibia"]
fly_leg_size += mean_length["tarsus"] if self.include_claw else 0
return nmf_size / fly_leg_size
[docs]
def find_stationary_indices(
self, array: np.ndarray, threshold: Optional[float] = 5e-5
) -> np.ndarray:
"""Find the indices in an array where the function value does not move significantly."""
indices_stat = np.where((np.diff(np.diff(array)) < threshold))
assert (
indices_stat
), f"Threshold ({threshold}) is too low to find stationary points, please increase it."
return indices_stat[0]
[docs]
def align_leg(
self,
leg_array: np.ndarray,
leg_name: Literal["RF", "LF", "RM", "LM", "RH", "LH"],
) -> np.ndarray:
"""Scales and translates the leg key point locations based on the model size and configuration.
This method takes a 3D array of leg key point positions and scales and translates them to align with a
predefined model size and joint configuration. It accounts for the relative positions
of key points and ensures that the scaled leg key points match the model size.
Parameters
----------
leg_array : np.ndarray
A 3D array containing the leg key point positions.
leg_name : str
A string indicating the name of the leg (e.g., "RF", "LF", ...) for alignment.
Returns
-------
np.ndarray
A new 3D array containing the scaled and aligned leg key point positions.
Notes
-----
* This method is used to align leg key point positions with a model of a fly"s leg.
* It calculates the scale factor and multiplies the first 4 or 5 segments with the scale factor.
"""
aligned_array = np.empty_like(leg_array)
fixed_coxa = AlignPose.get_fixed_pos(leg_array[:, 0, :])
mean_length = self.get_mean_length(leg_array, segment_is_leg=True)
scale_factor = self.find_scale_leg(leg_name, mean_length)
self.logger.info("Scale factor for %s leg: %s", leg_name, scale_factor)
for i in range(0, 5):
if i == 0:
# Translate the 3D pose coxa to the template coxa position
aligned_array[:, i, :] = (
np.zeros_like(leg_array[:, i, :])
+ self.body_template[f"{leg_name}_Coxa"]
)
else:
# Scale the length of the leg and
# move the leg to the predefined coxa pos
pos_aligned = (leg_array[:, i, :] - fixed_coxa).reshape(
-1, 3
) * scale_factor + self.body_template[f"{leg_name}_Coxa"]
aligned_array[:, i, :] = pos_aligned
return aligned_array.copy()
[docs]
def align_head(self, head_array: np.ndarray, side: str) -> np.ndarray:
"""Scales and translates the head key point locations based on the model size and configuration.
This method takes a 3D array of head key point positions and scales and translates
them to align with a predefined model size and configuration, such as a fly"s head.
It accounts for the relative positions of key points and ensures that the scaled
head key points match the model size.
Parameters
----------
head_array : np.ndarray
A 3D array containing the head key point positions.
side : str
A string indicating the side of the head (e.g., "R" or "L") for alignment.
Returns
-------
np.ndarray
A new 3D array containing the scaled and aligned head key point positions.
Raises
------
KeyError
If the "body_template" dictionary does not contain the required key names,
"Antenna_mid_thorax" or "Antenna".
Notes
-----
* This method is used to align head key point positions with a model of a fly"s head.
* It calculates the scale factor and translations necessary to match the model's head size and position.
"""
antbase2thoraxmid_real = _get_distance_btw_vecs(
head_array[:, 0, :], self.thorax_mid_pts
)
ant_size = self.get_mean_length(head_array, segment_is_leg=False)["antenna"]
if self.body_size["Antenna_mid_thorax"] and self.body_size["Antenna"]:
antbase2thoraxmid_tmp = self.body_size["Antenna_mid_thorax"]
ant_tmp = self.body_size["Antenna"]
else:
raise KeyError(
"""Nmf template dictionary does not contain
a key name <Antenna_mid_thorax> or <Antenna>
Please check the dictionary you provided."""
)
stationary_indices = self.find_stationary_indices(antbase2thoraxmid_real)
antenna_origin_fixed = AlignPose.get_fixed_pos(
head_array[stationary_indices, 0, :]
)
scale_base_ant = antbase2thoraxmid_tmp / _get_mean_quantile(
antbase2thoraxmid_real[stationary_indices]
)
scale_tip_ant = ant_tmp / _get_mean_quantile(ant_size)
self.logger.info(
"Scale factor antenna base %s: %s, ant itself: %s",
side,
scale_base_ant,
scale_tip_ant,
)
aligned_array = np.empty_like(head_array)
aligned_array[:, 0, :] = (
head_array[:, 0, :] - antenna_origin_fixed
) * scale_base_ant + self.body_template[f"{side}_Antenna_base"]
aligned_array[:, 1, :] = (
head_array[:, 1, :] - antenna_origin_fixed
) * scale_tip_ant + self.body_template[f"{side}_Antenna_base"]
return aligned_array.copy()