""" Functions for plotting and animation. """
import logging
import subprocess
import warnings
import itertools
from pathlib import Path
from typing import Tuple, List, Dict, Optional, Iterable
from tqdm import tqdm
import cv2
import numpy as np
from matplotlib import animation
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from seqikpy.utils import load_file
# Silence every warning, process-wide.
warnings.simplefilter("ignore")
# Timestamped format for the root handler; this module's level is set below.
logging.basicConfig(format=" %(asctime)s - %(levelname)s- %(message)s")
# Module logger: reports INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
[docs]
def generate_color_map(cmap: str, n: int) -> List:
    """Sample ``n`` evenly spaced colors from a matplotlib colormap.

    Parameters
    ----------
    cmap : str
        Name of the colormap passed to ``plt.get_cmap`` (e.g. "Reds").
    n : int
        Number of colors, taken at evenly spaced positions over [0, 1].

    Returns
    -------
    np.ndarray
        The colormap evaluated at those positions (one RGBA row per color).
    """
    colormap = plt.get_cmap(cmap)
    sample_points = np.linspace(0, 1, n)
    return colormap(sample_points)
[docs]
def get_video_writer(video_path, fps, output_shape):
    """Initialize and return an MP4 (``mp4v``) ``cv2.VideoWriter``.

    Parameters
    ----------
    video_path : str or Path
        Destination file; converted to ``str`` for the OpenCV binding.
    fps : int
        Frame rate of the output video.
    output_shape : Tuple[int, int]
        Frame size in numpy order, i.e. (height, width). Any extra trailing
        entry (such as a channel count) is ignored.

    Returns
    -------
    cv2.VideoWriter
    """
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    height, width = output_shape[:2]
    # OpenCV expects the frame size as integer (width, height).
    return cv2.VideoWriter(str(video_path), fourcc, fps, (int(width), int(height)))
[docs]
def resize_frame(original_size, target_size):
    """Compute the size a frame should be resized to.

    Both sizes are (height, width) pairs. A ``-1`` entry in ``target_size``
    means "derive this dimension from the other one so the aspect ratio of
    ``original_size`` is kept"; ``(-1, -1)`` keeps the original size. When
    the resulting size is larger than the original in *both* dimensions the
    original size is returned instead, so frames are never upscaled.

    Parameters
    ----------
    original_size : Tuple[int, int]
        Current (height, width) of the frame.
    target_size : Tuple[int, int]
        Requested (height, width), with ``-1`` as a wildcard.

    Returns
    -------
    Tuple[int, int]
        The (height, width) to resize to.
    """
    height_is_free = target_size[0] == -1
    width_is_free = target_size[1] == -1

    if height_is_free and width_is_free:
        candidate = original_size
    elif height_is_free:
        height_per_width = original_size[0] / original_size[1]
        candidate = (int(target_size[1] * height_per_width), target_size[1])
    elif width_is_free:
        width_per_height = original_size[1] / original_size[0]
        candidate = (target_size[0], int(target_size[0] * width_per_height))
    else:
        candidate = target_size

    grows_everywhere = (
        candidate[0] > original_size[0] and candidate[1] > original_size[1]
    )
    return original_size if grows_everywhere else candidate
[docs]
def resize_rgb_frame(frame, output_shape):
    """Resize ``frame`` to ``output_shape`` and swap its color channel order.

    Parameters
    ----------
    frame : np.ndarray
        Color (multi-channel) image.
    output_shape : Tuple[int, int]
        Target size as (height, width).

    Returns
    -------
    np.ndarray
        The resized image passed through ``cv2.COLOR_BGR2RGB``. That
        conversion swaps the first and third channels, so it maps RGB input
        to BGR just as well.
    """
    # cv2.resize takes its target size as (width, height).
    width_height = output_shape[::-1]
    resized = cv2.resize(frame, width_height)
    swapped = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    return swapped
[docs]
def make_video(
    video_path: str,
    frame_generator: Iterable,
    fps: int,
    output_shape: Tuple = (-1, 2880),
    n_frames: int = -1,
):
    """Makes a video from an iterable of images.

    Parameters
    ----------
    video_path : str
        Destination of the video (written with the ``mp4v`` codec).
    frame_generator : Iterable
        Any iterable of images (generator, list, ...). Every frame is resized
        to the size derived from the first frame.
    fps : int
        Frame rate of the output video.
    output_shape : Tuple, optional
        Target (height, width); ``-1`` derives that dimension from the aspect
        ratio. Frames are never upscaled. By default (-1, 2880).
    n_frames : int, optional
        Maximum number of frames to write; a value <= 0 writes all frames.
        By default -1.

    Notes
    -----
    If the iterable yields no frame, nothing is written and a warning is
    logged.
    """
    # iter() makes lists and other plain iterables work with next().
    frames = iter(frame_generator)
    try:
        first_frame = next(frames)
    except StopIteration:
        logger.warning("No frames to write, %s was not created", video_path)
        return

    # The first frame fixes the output size; all frames are resized to it.
    adjusted_output_shape = resize_frame(first_frame.shape[:2], output_shape)
    video_writer = get_video_writer(str(video_path), fps, adjusted_output_shape)
    progress = tqdm(
        itertools.chain([first_frame], frames),
        total=n_frames if n_frames > 0 else None,
    )
    try:
        for frame_count, frame in enumerate(progress, start=1):
            video_writer.write(resize_rgb_frame(frame, adjusted_output_shape))
            if frame_count == n_frames:
                break
    finally:
        # Always finalize the file, even if a frame fails to process.
        progress.close()
        video_writer.release()
    # Use the module logger: the root logger stays at WARNING, so
    # logging.info would drop this message.
    logger.info("Video is saved at %s", video_path)
[docs]
def video_frames_generator(
    video_path: Path,
    start_frame: int,
    end_frame: int,
    stim_lines: List[int],
    radius=30,
    center=(50, 50),
    color=(255, 0, 0),
):
    """Returns the frames as a generator in a given interval.

    Modifies the brightness and contrast of the images and draws a filled
    circle on the frames that fall inside the stimulation window.

    Parameters
    ----------
    video_path : Path
        Video path.
    start_frame : int
        First frame to yield (inclusive).
    end_frame : int
        Frame at which to stop (exclusive).
    stim_lines : List[int]
        Stimulation frame indices; frames from ``stim_lines[0]`` to
        ``stim_lines[-1]`` (inclusive) get the circle. ``None`` or an empty
        sequence disables it.
    radius : int, optional
        Circle radius in pixels, by default 30.
    center : Tuple[int, int], optional
        Circle center as (x, y) pixel coordinates, by default (50, 50).
    color : Tuple[int, int, int], optional
        Circle color in the frame's channel order, by default (255, 0, 0).

    Yields
    ------
    np.ndarray
        Adjusted frame. Iteration stops early if the video has no more frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    # Contrast (alpha) and brightness (beta): |alpha * pixel + beta|, saturated.
    alpha = 1.15
    beta = -5
    has_stim = stim_lines is not None and len(stim_lines) > 0
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        for t in range(start_frame, end_frame):
            ret, frame = cap.read()
            # Check before processing: on a failed read `frame` is None.
            if not ret:
                break
            adjusted_frame = cv2.convertScaleAbs(frame, alpha=alpha, beta=beta)
            # if stimulation, add a dot
            if has_stim and stim_lines[0] <= t <= stim_lines[-1]:
                adjusted_frame = cv2.circle(adjusted_frame, center, radius, color, -1)
            yield adjusted_frame
    finally:
        # Also runs when the consumer stops early and the generator is closed.
        cap.release()
[docs]
def get_plot_config(data_path: Path):
    """Get experimental conditions from the data path.

    Data path should look like:
    "/mnt/nas2/GO/7cam/220810_aJO-GAL4xUAS-CsChr/Fly001/001_RLF/behData/pose-3d"

    The trial folder (the component right before ``behData``) is checked for
    the tags below in this order; the first match wins and unmatched trials
    are labelled "Beh":

    ============  ========  =========  =========  ========
    tag           exp_type  plot_head  right leg  left leg
    ============  ========  =========  =========  ========
    "_RLF_coxa"   "coxa"    True       True       True
    "_RLF"        "RLF"     True       False      False
    "_LF"         "LF"      True       True       False
    "_RF"         "RF"      True       False      True
    (none)        "Beh"     True       True       True
    ============  ========  =========  =========  ========

    The "coxa" configuration additionally sets ``azim`` to 23.

    Parameters
    ----------
    data_path : Path or str
        Path to the ``pose-3d`` folder; strings are converted to ``Path``.

    Returns
    -------
    Tuple[str, Dict]
        Experiment type and plot configuration with keys ``plot_head``,
        ``plot_right_leg``, ``plot_left_leg`` (plus ``azim`` for "coxa").

    Raises
    ------
    AssertionError
        If the path does not end with ``<trial>/behData/pose-3d``.
    """
    data_path = Path(data_path)
    parts = data_path.parts
    assert parts and parts[-1] == "pose-3d", "The data path should end with pose-3d"
    assert (
        len(parts) >= 2 and parts[-2] == "behData"
    ), "The data path should contain behData"
    assert len(parts) >= 3, "The data path should contain a trial folder before behData"
    trial_type = parts[-3]

    # Order matters: "_RLF_coxa" contains "_RLF", so it has to be tested first.
    trial_rules = (
        (
            "_RLF_coxa",
            "coxa",
            {"plot_right_leg": True, "plot_left_leg": True, "azim": 23},
        ),
        ("_RLF", "RLF", {"plot_right_leg": False, "plot_left_leg": False}),
        ("_LF", "LF", {"plot_right_leg": True, "plot_left_leg": False}),
        ("_RF", "RF", {"plot_right_leg": False, "plot_left_leg": True}),
    )
    exp_type = "Beh"
    leg_config = {"plot_right_leg": True, "plot_left_leg": True}
    for tag, rule_exp_type, rule_config in trial_rules:
        if tag in trial_type:
            exp_type, leg_config = rule_exp_type, rule_config
            break

    plot_config = {"plot_head": True, **leg_config}
    return exp_type, plot_config
[docs]
def get_frames_from_video_ffmpeg(path):
    """Saves frames of a video using FFMPEG.

    Parameters
    ----------
    path : Path or str
        Video path. Frames are written as ``frame_%d.jpg`` into a sibling
        folder named after the video's stem plus ``_frames``
        (e.g. ``video.mp4`` -> ``video_frames/``).

    Returns
    -------
    Path
        The folder the frames were written to.

    Raises
    ------
    FileExistsError
        If the frames folder already exists.
    subprocess.CalledProcessError
        If ffmpeg exits with a non-zero status.
    """
    path = Path(path)
    # Build the folder name from the stem so that non-.mp4 videos do not map
    # onto the video file itself.
    write_path = path.with_name(f"{path.stem}_frames")
    write_path.mkdir()
    cmd = [
        "ffmpeg",
        "-i",
        str(path),
        # NOTE(review): as an output option, "-r 1" resamples the output to
        # 1 frame per second, so not every frame is saved. Confirm intended.
        "-r",
        "1",
        str(write_path / "frame_%d.jpg"),
    ]
    subprocess.run(cmd, check=True)
    return write_path
[docs]
def load_grid_plot_data(data_path: Path) -> Tuple[Dict, Dict]:
    """Loads the set of data necessary for plotting the grid.

    If ``body_joint_angles.pkl`` exists it provides all joint angles.
    Otherwise ``head_joint_angles.pkl`` is loaded and merged with
    ``leg_joint_angles.pkl`` when that file exists. On duplicate keys the leg
    entries win.

    Parameters
    ----------
    data_path : Path or str
        Data path where the pose3d and inverse kinematics are saved.

    Returns
    -------
    Tuple[Dict, Dict]
        Joint angles (head and leg) and the aligned pose
        (``pose3d_aligned.pkl``).
    """
    data_path = Path(data_path)
    body_angles_path = data_path / "body_joint_angles.pkl"
    if body_angles_path.is_file():
        joint_angles = load_file(body_angles_path)
    else:
        head_joint_angles = load_file(data_path / "head_joint_angles.pkl")
        leg_angles_path = data_path / "leg_joint_angles.pkl"
        # Leg angles are optional; fall back to head angles only.
        leg_joint_angles = (
            load_file(leg_angles_path) if leg_angles_path.is_file() else {}
        )
        joint_angles = {**head_joint_angles, **leg_joint_angles}
    aligned_pose = load_file(data_path / "pose3d_aligned.pkl")
    return joint_angles, aligned_pose
[docs]
def _plot_first_frame(ax3d, points, linestyle, marker_types):
    """Draw frame 0 of every key point in ``points`` and return the lines.

    Key points with more than three points per frame are drawn as thick lines
    ("Blues" shades if the name contains "L", "Reds" otherwise) using
    ``linestyle``. The others are drawn as markers colored from "RdBu". Each
    color map has ``len(points)`` entries, so every counter stays in range.
    """
    n_colors = len(points)
    color_map_right = generate_color_map(cmap="Reds", n=n_colors)
    color_map_left = generate_color_map(cmap="Blues", n=n_colors)
    color_map_scatter = generate_color_map(cmap="RdBu", n=n_colors)
    left_idx = right_idx = scatter_idx = 0
    lines = []
    for kp, kp_array in points.items():
        xs, ys, zs = kp_array[0, :, 0], kp_array[0, :, 1], kp_array[0, :, 2]
        if kp_array.shape[1] > 3:
            if "L" in kp:
                color = color_map_left[left_idx]
                left_idx += 1
            else:
                color = color_map_right[right_idx]
                right_idx += 1
            line = ax3d.plot(
                xs,
                ys,
                zs,
                label=kp,
                linestyle=linestyle,
                linewidth=4,
                color=color,
                alpha=0.85,
            )[0]
        else:
            # Unknown key points fall back to a circle instead of a KeyError.
            line = ax3d.plot(
                xs,
                ys,
                zs,
                lw=2.5,
                label=kp,
                marker=marker_types.get(kp, "o"),
                markersize=9,
                color=color_map_scatter[scatter_idx],
                alpha=0.7,
            )[0]
            scatter_idx += 1
        lines.append(line)
    return lines


def animate_3d_points(
    points3d: Dict[str, np.ndarray],
    export_path: Path,
    points3d_second: Dict[str, np.ndarray] = None,
    fps: int = 100,
    frame_no: int = 1000,
    format_video: str = "mp4",
    elev: int = 10,
    azim: int = 90,
    title: str = "",
    marker_types: Dict[str, str] = None,
) -> None:
    """Makes an animation of 3D pose.

    This code is intended for animating the raw 3D pose and
    forward kinematics from seqikpy. Both data sets are drawn with the same
    rules (see ``_plot_first_frame``): solid lines for ``points3d`` and
    dashed lines for ``points3d_second``. The dark style is applied through
    ``plt.rcParams.update``, so it stays active after the call.

    Parameters
    ----------
    points3d : Dict[str, np.ndarray]
        Dictionary containing the 3D pose, indexed ``[frame, point, xyz]``;
        usually this is the raw 3D pose.
    export_path : Path
        Path where the animation will be saved. ``.format_video`` is appended
        unless it already ends with .mp4, .avi or .mov.
    points3d_second : Dict[str, np.ndarray], optional
        Dictionary containing the 3D pose,
        usually this is the forward kinematics, by default None
    fps : int, optional
        Frames per second, by default 100
    frame_no : int, optional
        Maximum number of frames of the animation; clamped to the number of
        frames available in the data. By default 1000.
    format_video : str, optional
        Video format, by default "mp4"
    elev : int, optional
        Elevation of the point of view, by default 10
    azim : int, optional
        Azimuth of the point of view, by default 90
    title : str, optional
        Title of the video, by default ""
    marker_types : Dict[str, str], optional
        Marker types for each key point drawn as markers; missing names fall
        back to "o". By default None (circles for the heads, a cross for the
        neck).
    """
    # Dark background (global rcParams update).
    plt.rcParams.update(
        {
            "axes.facecolor": "black",
            "axes.edgecolor": "black",
            "axes.labelcolor": "white",
            "xtick.color": "white",
            "ytick.color": "white",
            "grid.color": "lightgray",
            "figure.facecolor": "black",
            "figure.edgecolor": "black",
            "savefig.facecolor": "black",
            "savefig.edgecolor": "black",
        }
    )
    if marker_types is None:
        marker_types = {
            "R_head": "o",
            "L_head": "o",
            "Neck": "x",
        }

    fig = plt.figure()
    ax3d = fig.add_subplot(projection="3d")
    ax3d.view_init(azim=azim, elev=elev)
    # Unfill the x/y panes, keep the z pane filled, and blacken all pane edges.
    ax3d.xaxis.pane.fill = False
    ax3d.yaxis.pane.fill = False
    ax3d.zaxis.pane.fill = True
    for axis in (ax3d.xaxis, ax3d.yaxis, ax3d.zaxis):
        axis.pane.set_edgecolor("black")

    line_data = _plot_first_frame(ax3d, points3d, "solid", marker_types)
    line_data_second = (
        _plot_first_frame(ax3d, points3d_second, "--", marker_types)
        if points3d_second is not None
        else []
    )

    # Hide ticks and tick labels on all three axes.
    ax3d.set_xticks([])
    ax3d.set_yticks([])
    ax3d.set_zticks([])
    for axis_name in ("x", "y", "z"):
        ax3d.tick_params(axis=axis_name, color="black")
    ax3d.set_xticklabels([])
    ax3d.set_yticklabels([])
    ax3d.set_zticklabels([])
    ax3d.set_title(title, loc="center")

    def update(frame, lines, points3d, lines_second, points3d_second):
        """Move every line to the pose at ``frame``."""
        for line, kp_array in zip(lines, points3d.values()):
            line.set_data(kp_array[frame, :, 0], kp_array[frame, :, 1])
            line.set_3d_properties(kp_array[frame, :, 2])
        if lines_second:
            for line, kp_array in zip(lines_second, points3d_second.values()):
                line.set_data(kp_array[frame, :, 0], kp_array[frame, :, 1])
                line.set_3d_properties(kp_array[frame, :, 2])

    # frame_no is an upper bound: never index past the shortest recording.
    n_available = [arr.shape[0] for arr in points3d.values()]
    if points3d_second is not None:
        n_available.extend(arr.shape[0] for arr in points3d_second.values())
    n_frames = min([frame_no, *n_available])

    line_ani = animation.FuncAnimation(
        fig,
        update,
        n_frames,
        fargs=(line_data, points3d, line_data_second, points3d_second),
        interval=10,
        blit=False,
    )
    logger.info("Making animation...")
    export_path = str(export_path)
    if not export_path.endswith((".mp4", ".avi", ".mov")):
        export_path += f".{format_video}"
    try:
        line_ani.save(export_path, fps=fps, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate memory.
        plt.close(fig)
    logger.info("Animation is saved at %s", export_path)
[docs]
def plot_3d_points(
    ax3d, points3d, export_path=None, t=0, marker_types=None, line_style="solid"
):
    """Plots 3D points at time t.

    Key points whose name contains "R" are drawn in "Reds" shades, those
    containing "L" (but not "R") in "Blues" shades, the rest in light grey.
    Key points with more than three points per frame are drawn as lines,
    the others as markers.

    Parameters
    ----------
    ax3d : Axes3D
        3D axis to draw on.
    points3d : Dict[str, np.ndarray]
        Key point name -> array indexed as ``[frame, point, xyz]``.
    export_path : Path, optional
        If given, the figure of ``ax3d`` is saved there, by default None.
    t : int, optional
        Frame to draw, by default 0.
    marker_types : Dict[str, str], optional
        Key point name -> marker style; missing names fall back to "o".
        By default circles for "R_head"/"L_head" and a cross for "Neck".
    line_style : str, optional
        Line style of the line-drawn key points, by default "solid".
    """
    if marker_types is None:
        marker_types = {
            "R_head": "o",
            "L_head": "o",
            "Neck": "x",
        }
    # One extra color per map so indexing can start at 1, skipping the first
    # (lightest) shade of the colormap.
    color_map_right = generate_color_map(
        cmap="Reds", n=sum("R" in kp for kp in points3d) + 1
    )
    color_map_left = generate_color_map(
        cmap="Blues", n=sum("L" in kp for kp in points3d) + 1
    )
    right_idx, left_idx = 1, 1
    for kp, kp_array in points3d.items():
        if "R" in kp:
            color = color_map_right[right_idx]
            right_idx += 1
        elif "L" in kp:
            color = color_map_left[left_idx]
            left_idx += 1
        else:
            color = "lightgrey"
        xs, ys, zs = kp_array[t, :, 0], kp_array[t, :, 1], kp_array[t, :, 2]
        if kp_array.shape[1] > 3:
            ax3d.plot(
                xs,
                ys,
                zs,
                label=kp,
                linestyle=line_style,
                linewidth=1.7,
                color=color,
            )
        else:
            ax3d.plot(
                xs,
                ys,
                zs,
                label=kp,
                marker=marker_types.get(kp, "o"),
                markersize=4.5,
                color=color,
            )
    if export_path is not None:
        # Save the figure this axis belongs to, not whichever one is current.
        ax3d.figure.savefig(export_path, bbox_inches="tight")
[docs]
def plot_trailing_kp(
    ax3d,
    points3d,
    segments_to_plot,
    export_path=None,
    t=0,
    trail=5,
    marker_type="x",
):
    """Plots the traces of key points over frames ``max(0, t - trail)`` to ``t - 1``.

    Parameters
    ----------
    ax3d : Axes3D
        3D axis to draw on.
    points3d : Dict[str, np.ndarray]
        Key point name -> array indexed as ``[frame, point, xyz]``.
    segments_to_plot : Dict[str, int or Sequence[int]]
        Key point name -> index (or indices) along the point axis to trace.
        Names containing "R" use "Reds" shades, "L" use "Blues", others grey.
    export_path : Path, optional
        If given, the figure of ``ax3d`` is saved there, by default None.
    t : int, optional
        Current frame; the trace stops just before it, by default 0.
    trail : int, optional
        Number of previous frames to show, by default 5.
    marker_type : str, optional
        Scatter marker, by default "x".
    """
    # Start indexing at 1 to skip the first (lightest) shade of each map.
    color_map_right = generate_color_map(cmap="Reds", n=len(points3d) + 1)
    color_map_left = generate_color_map(cmap="Blues", n=len(points3d) + 1)
    right_idx, left_idx = 1, 1
    window = slice(max(0, t - trail), t)
    for kp, ind in segments_to_plot.items():
        if "R" in kp:
            color = color_map_right[right_idx]
            right_idx += 1
        elif "L" in kp:
            color = color_map_left[left_idx]
            left_idx += 1
        else:
            color = "grey"
        ax3d.scatter(
            points3d[kp][window, ind, 0],
            points3d[kp][window, ind, 1],
            points3d[kp][window, ind, 2],
            label=kp,
            marker=marker_type,
            color=color,
        )
    if export_path is not None:
        # Save the figure this axis belongs to, not whichever one is current.
        ax3d.figure.savefig(export_path, bbox_inches="tight")
[docs]
def plot_joint_angle(
    ax: plt.Axes,
    kinematics_data: Dict[str, np.ndarray],
    angles_to_plot: List,
    degrees: bool = True,
    until_t: int = -1,
    stim_lines: List[int] = None,
    show_legend: bool = True,
    export_path: Path = None,
):
    """Plot joint angles from a given kinematics data.

    Each angle is drawn as a line with a marker on its last plotted sample.
    The legend label is the last two "_"-separated parts of the column name,
    prefixed with "ant." for the pitch/yaw/roll R/L angles.

    Parameters
    ----------
    ax : plt.Axes
        Axis where the plot will be displayed.
    kinematics_data : Dict[str, np.ndarray]
        Dictionary containing the kinematics, pose and angle
    angles_to_plot : List
        Angles to plot. Exact column name should be given.
    degrees : bool, optional
        Convert radians to degrees, by default True
    until_t : int, optional
        Plots the samples before index ``until_t`` (``data[:until_t]``);
        the default -1 plots every sample.
    stim_lines : List[int], optional
        Plots vertical lines in indicated locations,
        by default None
    show_legend : bool, optional
        Shows legend, by default True
    export_path : Path, optional
        Path where the figure of ``ax`` will be saved, by default None
    """
    colors = generate_color_map(cmap="Set2", n=len(angles_to_plot))
    convert2deg = 180 / np.pi if degrees else 1
    # -1 means "every sample"; slicing with [:-1] would drop the last one.
    end = None if until_t == -1 else until_t
    antenna_labels = {"pitch R", "pitch L", "yaw R", "yaw L", "roll R", "roll L"}
    for color, joint_name in zip(colors, angles_to_plot):
        label = " ".join(joint_name.split("_")[-2:])
        if label in antenna_labels:
            label = "ant. " + label
        ax.plot(
            np.array(kinematics_data[joint_name][:end]) * convert2deg,
            "o-",
            ms=4.5,
            markevery=[-1],
            label=label,
            color=color,
        )
    if stim_lines is not None:
        ax.vlines(stim_lines, -200, 200, "red", lw=0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    if show_legend:
        ax.legend(bbox_to_anchor=(1.2, 1), frameon=False, borderaxespad=0.0)
    if export_path is not None:
        # Save the figure this axis belongs to, not whichever one is current.
        ax.figure.savefig(export_path, bbox_inches="tight")
[docs]
def plot_grid(
    img_front: np.ndarray,
    img_side: np.ndarray,
    aligned_pose: Dict[str, np.ndarray],
    joint_angles: Dict[str, np.ndarray],
    leg_angles_to_plot: List[str],
    head_angles_to_plot: List[str],
    t: int,
    t_start: int,
    t_end: int,
    t_interval: int = 20,
    fps: int = 100,
    trail: int = 30,
    key_points_to_trail: Optional[Dict[str, Iterable]] = None,
    marker_types_3d: Optional[Dict[str, str]] = None,
    marker_trail: Optional[str] = None,
    stim_lines: Optional[List[int]] = None,
    export_path: Optional[Path] = None,
    **kwargs,
):
    """
    Plots an instance of the animal recording, 3D pose,
    head and leg joint angles in a grid layout. This code
    is intended to use in a for loop to plot and save all
    the frames from `t_start` to `t_end` to make a video afterwards.

    NOTE: This function is intended to plot leg and head joint angles together,
    it will not work if any of these data is missing.

    Parameters
    ----------
    img_front : np.ndarray
        Image of the fly at frame t on the front camera.
    img_side : np.ndarray
        Image of the fly at frame t on the side camera.
    aligned_pose : Dict[str, np.ndarray]
        Aligned 3D pose.
    joint_angles : Dict[str, np.ndarray]
        Joint angles keyed by column name (e.g. "Angle_LF_ThC_yaw").
    leg_angles_to_plot : List[str]
        Leg joint angle names without the side; they are expanded to
        "Angle_LF_<name>" and "Angle_RF_<name>". Example::

            ["ThC_yaw", "ThC_pitch", "ThC_roll", "CTr_pitch",
             "CTr_roll", "FTi_pitch", "TiTa_pitch"]

    head_angles_to_plot : List[str]
        Exact names of the head joint angles.
    t : int
        Frame number t.
    t_start : int
        Start of the time series, i.e., joint angles.
    t_end : int
        End of the time series, i.e., joint angles.
    t_interval : int, optional
        Interval of frame numbers in between x ticks, by default 20
    fps : int, optional
        Frames per second, used to label the time axis in seconds,
        by default 100
    trail : int, optional
        Number of previous frames where the key point will be visible,
        by default 30
    key_points_to_trail : Dict[str, Iterable], optional
        Key point name -> point index/indices whose recent positions are
        drawn, by default None (no trail).
    marker_types_3d : Dict[str, str], optional
        Key point name -> marker style for key points drawn as markers.
    marker_trail : str, optional
        Marker style of the trailing key points.
    stim_lines : List[int], optional
        Stimulation indicators, by default None
    export_path : Path, optional
        Path where the plot will be saved, by default None
    **kwargs
        ``plot_head``, ``plot_left_leg``, ``plot_right_leg`` (bool, default
        True) toggle the joint angle panels; ``azim`` (default 7) sets the
        azimuth of the 3D view. Other keys are ignored.

    Returns
    -------
    Fig
        Figure.
    """
    plot_right_leg = kwargs.pop("plot_right_leg", True)
    plot_left_leg = kwargs.pop("plot_left_leg", True)
    plot_head = kwargs.pop("plot_head", True)
    azim = kwargs.pop("azim", 7)
    assert (
        t_start <= t <= t_end
    ), "t_start should be smaller than t_end, t should be in between"

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(14, 6), dpi=120)
    gs = GridSpec(3, 4, figure=fig)
    # Left column: side camera, front camera, 3D pose.
    ax_img_side = fig.add_subplot(gs[0, :2])
    ax_img_front = fig.add_subplot(gs[1, :2])
    ax_pose = fig.add_subplot(gs[2, :2], projection="3d")
    # Right column: head, left leg, right leg joint angles.
    ax_head = fig.add_subplot(gs[0, 2:])
    ax_left_leg = fig.add_subplot(gs[1, 2:])
    ax_right_leg = fig.add_subplot(gs[2, 2:])

    # Camera images
    ax_img_side.imshow(img_side, vmin=0, vmax=255, cmap="gray")
    ax_img_front.imshow(img_front, vmin=0, vmax=255, cmap="gray")
    ax_img_side.axis("off")
    ax_img_front.axis("off")

    # 3D pose
    plot_3d_points(ax_pose, aligned_pose, marker_types=marker_types_3d, t=t)
    if key_points_to_trail is not None:
        plot_trailing_kp(
            ax_pose,
            aligned_pose,
            key_points_to_trail,
            marker_type=marker_trail,
            trail=trail,
            t=t,
        )
    ax_pose.view_init(azim=azim, elev=10)
    ax_pose.set_xlim3d([-0.45, 0.45])
    ax_pose.set_ylim3d([-0.6, 0.6])
    ax_pose.set_zlim3d([0.42, 1.1])
    ax_pose.set_xticks([])
    ax_pose.set_yticks([])
    ax_pose.set_zticks([])
    ax_pose.axis("off")

    # Joint angle panels
    if plot_head:
        plot_joint_angle(
            ax_head,
            joint_angles,
            angles_to_plot=head_angles_to_plot,
            until_t=t,
            stim_lines=stim_lines,
        )
    if plot_left_leg:
        plot_joint_angle(
            ax_left_leg,
            joint_angles,
            angles_to_plot=[f"Angle_LF_{ja}" for ja in leg_angles_to_plot],
            until_t=t,
            stim_lines=stim_lines,
        )
    if plot_right_leg:
        plot_joint_angle(
            ax_right_leg,
            joint_angles,
            angles_to_plot=[f"Angle_RF_{ja}" for ja in leg_angles_to_plot],
            until_t=t,
            show_legend=False,
            stim_lines=stim_lines,
        )

    for ax, y_ticks in (
        (ax_head, [-90, 0, 70]),
        (ax_left_leg, [-160, 0, 160]),
        (ax_right_leg, [-160, 0, 160]),
    ):
        ax.set_xlim((t_start, t_end))
        ax.set_ylim((y_ticks[0], y_ticks[-1]))
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_yticks(ticks=y_ticks)
        ax.set_yticklabels(labels=y_ticks)
    # Only the bottom panel keeps a time axis.
    for ax in (ax_head, ax_left_leg):
        ax.spines["bottom"].set_visible(False)
        ax.set_xticks([])
    time_ticks = np.arange(t_start, t_end + t_interval, t_interval)
    ax_right_leg.set_xticks(ticks=time_ticks)
    ax_right_leg.set_xticklabels(labels=time_ticks / fps)
    ax_right_leg.set_xlabel("Time (s)")

    # Final layout (originally generated with pylustrator), in figure coords.
    fig.set_size_inches(21.890000 / 2.54, 11.420000 / 2.54, forward=True)
    fig.text(
        0.399756,
        0.918650,
        "Head and antennae joint angles (deg)",
        transform=fig.transFigure,
    )
    fig.text(
        0.399756,
        0.635718,
        "Left front leg joint angles (deg)",
        transform=fig.transFigure,
    )
    fig.text(
        0.399756,
        0.352188,
        "Right front leg joint angles (deg)",
        transform=fig.transFigure,
    )
    ax_img_side.set_position([0.049668, 0.665633, 0.277935, 0.266469])
    ax_img_front.set_position([0.049668, 0.382601, 0.277935, 0.266469])
    ax_pose.set_position([0.123600, 0.069459, 0.136149, 0.261065])
    ax_head.set_position([0.400791, 0.677390, 0.383414, 0.225807])
    ax_left_leg.set_position([0.400791, 0.396851, 0.383414, 0.225807])
    ax_right_leg.set_position([0.400791, 0.112723, 0.383414, 0.225807])
    ax_head.legend(loc=(1.068, -0.1324), frameon=False)
    ax_left_leg.legend(loc=(1.064, -0.569), frameon=False)

    if export_path is not None:
        fig.savefig(export_path, bbox_inches="tight")
        logger.info("Figure saved at %s", export_path)
    return fig
[docs]
def plot_grid_generator(
    fly_frames_front: np.ndarray,
    fly_frames_side: np.ndarray,
    aligned_pose: Dict[str, np.ndarray],
    joint_angles: Dict[str, np.ndarray],
    leg_angles_to_plot: List[str],
    head_angles_to_plot: List[str],
    t_start: int,
    t_end: int,
    t_interval: int = 20,
    fps: int = 100,
    trail: int = 30,
    key_points_to_trail: Optional[Dict[str, Iterable]] = None,
    marker_types_3d: Dict[str, str] = None,
    marker_trail: Optional[str] = "x",
    stim_lines: List[int] = None,
    export_path: Path = None,
    **kwargs,
):
    """Generator for plotting grid: yields every frame as an RGB array.

    Frame ``i`` of ``fly_frames_front`` / ``fly_frames_side`` is drawn with
    ``plot_grid`` at time step ``t_start + i``. Iteration stops when the
    shorter of the two frame sequences is exhausted. Extra keyword arguments
    (``plot_head``, ``plot_left_leg``, ``plot_right_leg``, ``azim``) are
    forwarded to ``plot_grid``.

    ``key_points_to_trail`` maps key point names to point indices and is
    passed on to ``plot_grid``.

    NOTE: ``export_path`` is forwarded unchanged, so when it is set every
    frame is saved to (and overwrites) the same file.

    Yields
    ------
    np.ndarray
        The rendered figure, as returned by ``fig_to_array``.
    """
    frame_pairs = zip(fly_frames_front, fly_frames_side)
    for offset, (fly_img_front, fly_img_side) in enumerate(frame_pairs):
        fig = plot_grid(
            img_front=fly_img_front,
            img_side=fly_img_side,
            aligned_pose=aligned_pose,
            joint_angles=joint_angles,
            leg_angles_to_plot=leg_angles_to_plot,
            head_angles_to_plot=head_angles_to_plot,
            marker_types_3d=marker_types_3d,
            marker_trail=marker_trail,
            key_points_to_trail=key_points_to_trail,
            t=t_start + offset,
            t_start=t_start,
            t_end=t_end,
            t_interval=t_interval,
            fps=fps,
            trail=trail,
            stim_lines=stim_lines,
            export_path=export_path,
            **kwargs,
        )
        # Unregister the figure from pyplot so open figures don't pile up;
        # fig_to_array renders it through its own Agg canvas.
        plt.close(fig)
        yield fig_to_array(fig)
[docs]
def fig_to_array(fig):
    """Converts a matplotlib figure into an RGB array.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to render. It is attached to a new Agg canvas.

    Returns
    -------
    np.ndarray
        Contiguous ``uint8`` array of shape (height, width, 3).
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    # ``tostring_rgb`` is deprecated (Matplotlib 3.8) and later removed;
    # ``buffer_rgba`` is the supported accessor. Its shape matches the
    # rendered pixel size, which ``get_width_height()`` does not guarantee
    # on HiDPI canvases.
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop alpha and copy so the result does not alias the canvas buffer.
    return np.ascontiguousarray(rgba[..., :3])