""" Functions for plotting and animation. """importloggingimportsubprocessimportwarningsimportitertoolsfrompathlibimportPathfromtypingimportTuple,List,Dict,Optional,Iterablefromtqdmimporttqdmimportcv2importnumpyasnpfrommatplotlibimportanimationfrommatplotlib.gridspecimportGridSpecimportmatplotlib.pyplotaspltfrommatplotlib.backends.backend_aggimportFigureCanvasAggasFigureCanvasfromseqikpy.utilsimportload_file# Ignore the warningswarnings.filterwarnings("ignore")# Change the logging level herelogging.basicConfig(format=" %(asctime)s - %(levelname)s- %(message)s")# Get the logger of the modulelogger=logging.getLogger(__name__)logger.setLevel(logging.INFO)
def generate_color_map(cmap: str, n: int) -> List:
    """Sample ``n`` colors from a named matplotlib colormap.

    The colormap is looked up with ``plt.get_cmap`` and evaluated at ``n``
    positions spaced evenly over ``[0, 1]`` (both ends included).

    Parameters
    ----------
    cmap : str
        Name of the colormap, e.g. ``"Reds"``.
    n : int
        Number of colors to sample.

    Returns
    -------
    List
        Whatever the colormap returns for an array of positions
        (presumably one RGBA row per sample -- confirm against matplotlib).
    """
    sample_positions = np.linspace(0, 1, n)
    return plt.get_cmap(cmap)(sample_positions)
def get_video_writer(video_path, fps, output_shape):
    """Create a ``cv2.VideoWriter`` that encodes with the ``mp4v`` codec.

    ``output_shape`` is handed to OpenCV in reversed order; callers in this
    module pass a size derived from ``frame.shape[:2]``.
    """
    frame_size = output_shape[::-1]
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(video_path, codec, fps, frame_size)
def resize_frame(original_size, target_size):
    """Work out the output size for a frame.

    An entry of ``-1`` in ``target_size`` means "derive this entry from the
    other one so the proportions of ``original_size`` are kept". When both
    entries are ``-1`` the original size is used. A result that is larger
    than the original in *both* entries is rejected and the original size is
    returned instead.

    Parameters
    ----------
    original_size : sequence of two numbers
        Size of the incoming frame.
    target_size : sequence of two ints
        Requested size; either entry may be ``-1``.

    Returns
    -------
    The size to use, in the same entry order as the inputs.
    """
    first_target, second_target = target_size[0], target_size[1]
    first_given = first_target != -1
    second_given = second_target != -1

    if first_given and second_given:
        candidate = target_size
    elif second_given:
        # Only the second entry was requested: scale the first one to match.
        scale = original_size[0] / original_size[1]
        candidate = (int(second_target * scale), second_target)
    elif first_given:
        # Only the first entry was requested: scale the second one to match.
        scale = original_size[1] / original_size[0]
        candidate = (first_target, int(first_target * scale))
    else:
        candidate = original_size

    grows_in_both = (
        candidate[0] > original_size[0] and candidate[1] > original_size[1]
    )
    return original_size if grows_in_both else candidate
def resize_rgb_frame(frame, output_shape):
    """Resize ``frame`` and swap its channel order with ``COLOR_BGR2RGB``.

    ``output_shape`` is reversed before it is given to ``cv2.resize``.
    """
    target_size = output_shape[::-1]
    return cv2.cvtColor(cv2.resize(frame, target_size), cv2.COLOR_BGR2RGB)
def make_video(
    video_path: str,
    frame_generator: Iterable,
    fps: int,
    output_shape: Tuple = (-1, 2880),
    n_frames: int = -1,
):
    """Write the frames produced by an iterable to a video file.

    The first frame is used to determine the output size (via
    ``resize_frame``); every frame is then resized, channel-swapped by
    ``resize_rgb_frame`` and written.

    Parameters
    ----------
    video_path : str
        Destination of the video. Path-like objects are converted with
        ``str`` before being handed to OpenCV.
    frame_generator : Iterable
        Any iterable (generator, list, ...) of image arrays.
    fps : int
        Frames per second of the output video.
    output_shape : Tuple, optional
        Requested output size; ``-1`` entries are derived by
        ``resize_frame``. By default ``(-1, 2880)``.
    n_frames : int, optional
        Stop after this many frames. Non-positive values write every frame,
        by default -1.

    Raises
    ------
    ValueError
        If ``frame_generator`` yields no frames.
    """
    # ``iter`` makes plain sequences work too; ``next`` needs an iterator.
    frame_iterator = iter(frame_generator)
    try:
        first_frame = next(frame_iterator)
    except StopIteration:
        raise ValueError("frame_generator did not yield any frames") from None

    adjusted_output_shape = resize_frame(first_frame.shape[:2], output_shape)
    video_writer = get_video_writer(str(video_path), fps, adjusted_output_shape)

    try:
        all_frames = itertools.chain([first_frame], frame_iterator)
        for frame_count, frame in tqdm(enumerate(all_frames)):
            video_writer.write(resize_rgb_frame(frame, adjusted_output_shape))
            if 0 < n_frames == frame_count + 1:
                break
    finally:
        # Always finalize the file, even if a frame fails mid-way.
        video_writer.release()

    # Use the module logger (INFO level); the root logger would drop this.
    logger.info("Video is saved at %s", video_path)
def video_frames_generator(
    video_path: Path,
    start_frame: int,
    end_frame: int,
    stim_lines: List[int],
    radius=30,
    center=(50, 50),
    color=(255, 0, 0),
):
    """Yield contrast/brightness-adjusted frames of a video in an interval.

    Frames with index ``t`` such that ``stim_lines[0] <= t <= stim_lines[-1]``
    additionally get a filled circle drawn on them as a stimulation marker.

    Parameters
    ----------
    video_path : Path
        Video path.
    start_frame : int
        Index of the first frame to read.
    end_frame : int
        Index at which reading stops (exclusive).
    stim_lines : List[int]
        Frame indices delimiting the stimulation period; only the first and
        last entries are used. ``None`` or an empty sequence disables the
        marker.
    radius : int, optional
        Radius of the marker circle, by default 30.
    center : tuple, optional
        Center of the marker circle, by default (50, 50).
    color : tuple, optional
        Color passed to ``cv2.circle``, by default (255, 0, 0).

    Yields
    ------
    Frame
        Adjusted frames, one per successfully read frame. Iteration stops
        early if the video runs out of frames.
    """
    # Contrast (alpha) and brightness (beta) applied to every frame.
    alpha = 1.15
    beta = -5

    has_stim = stim_lines is not None and len(stim_lines) > 0
    stim_start, stim_end = (stim_lines[0], stim_lines[-1]) if has_stim else (0, -1)

    cap = cv2.VideoCapture(str(video_path))
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        for t in range(start_frame, end_frame):
            ret, frame = cap.read()
            # Check the read status *before* touching the frame: on failure
            # there is no image to process.
            if not ret:
                break

            adjusted_frame = cv2.convertScaleAbs(frame, alpha=alpha, beta=beta)
            if has_stim and stim_start <= t <= stim_end:
                adjusted_frame = cv2.circle(
                    adjusted_frame, center, radius, color, -1
                )
            yield adjusted_frame
    finally:
        # Runs on exhaustion, on error, and when the generator is closed early.
        cap.release()
def get_plot_config(data_path: Path):
    """Derive the experiment type and plotting flags from a data path.

    The path must end in ``.../<trial>/behData/pose-3d``, for example
    "/mnt/nas2/GO/7cam/220810_aJO-GAL4xUAS-CsChr/Fly001/001_RLF/behData/pose-3d".
    The ``<trial>`` component is searched for a known marker; the first
    matching rule decides the result, and ``"Beh"`` (everything plotted) is
    used when nothing matches.

    Returns
    -------
    tuple
        ``(exp_type, plot_config)`` where ``plot_config`` holds the keys
        ``plot_head``, ``plot_right_leg``, ``plot_left_leg`` and, for the
        coxa experiments only, ``azim``.
    """
    assert data_path.parts[-1] == "pose-3d", "The data path should end with pose-3d"
    assert data_path.parts[-2] == "behData", "The data path should contain behData"

    # (marker, experiment type, plot right leg, plot left leg, extra settings)
    # Order matters: "_RLF_coxa" has to be tested before the shorter "_RLF".
    rules = (
        ("_RLF_coxa", "coxa", True, True, {"azim": 23}),
        ("_RLF", "RLF", False, False, {}),
        ("_LF", "LF", True, False, {}),
        ("_RF", "RF", False, True, {}),
    )

    trial_type = data_path.parts[-3]
    for marker, exp_type, right_leg, left_leg, extras in rules:
        if marker in trial_type:
            break
    else:
        exp_type, right_leg, left_leg, extras = "Beh", True, True, {}

    plot_config = {
        "plot_head": True,
        "plot_right_leg": right_leg,
        "plot_left_leg": left_leg,
        **extras,
    }
    return exp_type, plot_config
def get_frames_from_video_ffmpeg(path):
    """Extract frames of a video into a sibling folder using FFMPEG.

    The frames are written as ``frame_%d.jpg`` into a new folder that sits
    next to the video and is named after it with ``_frames`` appended
    (``clip.mp4`` -> ``clip_frames``). The folder name is derived from the
    file stem, so containers other than ``.mp4`` work as well.

    Parameters
    ----------
    path : Path
        Video path (a plain string is accepted too).

    Raises
    ------
    FileExistsError
        If the frames folder already exists.
    subprocess.CalledProcessError
        If ``ffmpeg`` exits with a non-zero status.
    """
    path = Path(path)
    write_path = path.with_name(f"{path.stem}_frames")
    # Deliberately strict: an existing folder is not silently reused.
    write_path.mkdir()

    cmd = [
        "ffmpeg",
        "-i",
        str(path),
        "-r",
        "1",
        str(write_path / "frame_%d.jpg"),
    ]
    subprocess.run(cmd, check=True)
def load_grid_plot_data(data_path: Path) -> Tuple[Dict, Dict]:
    """Load the joint angles and aligned pose needed for the grid plot.

    If ``body_joint_angles.pkl`` exists in ``data_path`` it is used as is.
    Otherwise ``head_joint_angles.pkl`` is loaded and merged with
    ``leg_joint_angles.pkl`` when that file is present (leg entries override
    head entries with the same key).

    Parameters
    ----------
    data_path : Path
        Folder where the pose3d and inverse kinematics results are saved.

    Returns
    -------
    Tuple[Dict, Dict]
        ``(joint_angles, aligned_pose)``; the aligned pose is read from
        ``pose3d_aligned.pkl``.
    """
    body_angles_file = data_path / "body_joint_angles.pkl"

    if body_angles_file.is_file():
        joint_angles = load_file(body_angles_file)
    else:
        head_joint_angles = load_file(data_path / "head_joint_angles.pkl")
        leg_angles_file = data_path / "leg_joint_angles.pkl"
        # Leg angles are optional; fall back to an empty mapping.
        leg_joint_angles = (
            load_file(leg_angles_file) if leg_angles_file.is_file() else {}
        )
        joint_angles = {**head_joint_angles, **leg_joint_angles}

    aligned_pose = load_file(data_path / "pose3d_aligned.pkl")
    return joint_angles, aligned_pose
def _draw_initial_pose(
    ax3d,
    points3d,
    marker_types,
    color_maps,
    line_min_points,
    linestyle,
):
    """Draw frame 0 of every key-point array and return the line artists.

    Arrays with more than ``line_min_points`` points are drawn as thick
    lines colored from the left/right colormaps; the others are drawn with
    the marker registered for their name in ``marker_types``.
    """
    color_map_left, color_map_right, color_map_scatter = color_maps
    marker_idx, left_idx, right_idx = 0, 0, 0
    artists = []

    for kp, points3d_array in points3d.items():
        xs = points3d_array[0, :, 0]
        ys = points3d_array[0, :, 1]
        zs = points3d_array[0, :, 2]

        if points3d_array.shape[1] > line_min_points:
            if "L" in kp:
                color = color_map_left[left_idx]
                left_idx += 1
            else:
                color = color_map_right[right_idx]
                right_idx += 1
            artist = ax3d.plot(
                xs,
                ys,
                zs,
                label=kp,
                linestyle=linestyle,
                linewidth=4,
                color=color,
                alpha=0.85,
            )[0]
        else:
            artist = ax3d.plot(
                xs,
                ys,
                zs,
                lw=2.5,
                label=kp,
                marker=marker_types[kp],
                markersize=9,
                color=color_map_scatter[marker_idx],
                alpha=0.7,
            )[0]
            marker_idx += 1

        artists.append(artist)

    return artists


def animate_3d_points(
    points3d: Dict[str, np.ndarray],
    export_path: Path,
    points3d_second: Optional[Dict[str, np.ndarray]] = None,
    fps: int = 100,
    frame_no: int = 1000,
    format_video: str = "mp4",
    elev: int = 10,
    azim: int = 90,
    title: str = "",
    marker_types: Optional[Dict[str, str]] = None,
) -> None:
    """Makes an animation of 3D pose.

    This code is intended for animating the raw 3D pose and
    forward kinematics from seqikpy.

    Every array is indexed as ``array[frame, point, axis]`` with the first
    three axes used as x, y and z.

    Parameters
    ----------
    points3d : Dict[str, np.ndarray]
        Dictionary containing the 3D pose, usually this is the raw 3D pose.
        Drawn with solid lines.
    export_path : Path
        Path where the animation will be saved. ``.{format_video}`` is
        appended unless the path already ends with ".mp4", ".avi" or ".mov".
    points3d_second : Dict[str, np.ndarray], optional
        Dictionary containing the 3D pose usually this is the forward
        kinematics, by default None. Drawn with dashed lines.
    fps : int, optional
        Frames per second, by default 100
    frame_no : int, optional
        Maximum number of frames of the animation, by default 1000. It is
        capped at the number of frames available in the shortest array.
    format_video : str, optional
        Video format, by default "mp4"
    elev : int, optional
        Elevation of the point of view, by default 10
    azim : int, optional
        Azimuth of the point of view, by default 90
    title : str, optional
        Title of the video, by default ""
    marker_types : Dict[str, str], optional
        Marker types for each key point that is drawn with markers, by
        default None (markers for "R_head", "L_head" and "Neck").

    Raises
    ------
    KeyError
        If a key point drawn with markers has no entry in ``marker_types``.

    Notes
    -----
    The dark color scheme is applied through ``plt.rcParams`` and therefore
    stays active after the function returns.
    """
    # Dark background
    plt.rcParams.update(
        {
            "axes.facecolor": "black",
            "axes.edgecolor": "black",
            "axes.labelcolor": "white",
            "xtick.color": "white",
            "ytick.color": "white",
            "grid.color": "lightgray",
            "figure.facecolor": "black",
            "figure.edgecolor": "black",
            "savefig.facecolor": "black",
            "savefig.edgecolor": "black",
        }
    )

    if marker_types is None:
        marker_types = {
            "R_head": "o",
            "L_head": "o",
            "Neck": "x",
        }

    pose_sets = [points3d]
    if points3d_second is not None:
        pose_sets.append(points3d_second)

    # Never request more frames than the shortest array can provide;
    # otherwise saving the animation fails with an IndexError.
    available_frames = min(
        (len(array) for pose in pose_sets for array in pose.values()),
        default=frame_no,
    )
    n_frames = min(frame_no, available_frames)

    # Size the colormaps for the larger pose so indexing stays in range.
    n_colors = max(len(pose) for pose in pose_sets)
    color_maps = (
        generate_color_map(cmap="Blues", n=n_colors),  # left side
        generate_color_map(cmap="Reds", n=n_colors),  # right side
        generate_color_map(cmap="RdBu", n=n_colors),  # marker-only key points
    )

    fig = plt.figure()
    try:
        ax3d = fig.add_subplot(projection="3d")
        ax3d.view_init(azim=azim, elev=elev)

        # Pane fill and edge colors
        ax3d.xaxis.pane.fill = False
        ax3d.yaxis.pane.fill = False
        ax3d.zaxis.pane.fill = True
        ax3d.xaxis.pane.set_edgecolor("black")
        ax3d.yaxis.pane.set_edgecolor("black")
        ax3d.zaxis.pane.set_edgecolor("black")

        # NOTE(review): the two poses use different thresholds (3 vs 4) for
        # "draw as a line"; kept as in the original -- confirm it is intended.
        line_data = _draw_initial_pose(
            ax3d,
            points3d,
            marker_types,
            color_maps,
            line_min_points=3,
            linestyle="solid",
        )
        line_data_second = []
        if points3d_second is not None:
            line_data_second = _draw_initial_pose(
                ax3d,
                points3d_second,
                marker_types,
                color_maps,
                line_min_points=4,
                linestyle="--",
            )

        # Hide ticks and tick labels
        ax3d.set_xticks([])
        ax3d.set_yticks([])
        ax3d.set_zticks([])
        ax3d.tick_params(axis="x", color="black")
        ax3d.tick_params(axis="y", color="black")
        ax3d.tick_params(axis="z", color="black")
        ax3d.set_xticklabels([])
        ax3d.set_yticklabels([])
        ax3d.set_zticklabels([])
        ax3d.set_title(title, loc="center")

        def update(frame, lines, points3d, lines_second, points3d_second):
            """Move every line artist to its position at ``frame``."""
            for line, points3d_array in zip(lines, points3d.values()):
                line.set_data(
                    points3d_array[frame, :, 0], points3d_array[frame, :, 1]
                )
                line.set_3d_properties(points3d_array[frame, :, 2])

            if lines_second:
                for line, points3d_second_array in zip(
                    lines_second, points3d_second.values()
                ):
                    line.set_data(
                        points3d_second_array[frame, :, 0],
                        points3d_second_array[frame, :, 1],
                    )
                    line.set_3d_properties(points3d_second_array[frame, :, 2])

        # Creating the Animation object
        line_ani = animation.FuncAnimation(
            fig,
            update,
            n_frames,
            fargs=(line_data, points3d, line_data_second, points3d_second),
            interval=10,
            blit=False,
        )
        logger.info("Making animation...")

        export_path = str(export_path)
        if not export_path.endswith((".mp4", ".avi", ".mov")):
            export_path += f".{format_video}"

        line_ani.save(export_path, fps=fps, dpi=300)
        logger.info("Animation is saved at %s", export_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
def plot_3d_points(
    ax3d, points3d, export_path=None, t=0, marker_types=None, line_style="solid"
):
    """Plots 3D points at time t.

    Parameters
    ----------
    ax3d : 3D matplotlib axes
        Axes to draw on.
    points3d : dict
        Maps key point names to arrays indexed as ``[frame, point, axis]``.
        Names containing "R" are colored from the "Reds" colormap, names
        containing "L" (and no "R") from "Blues", all others light grey.
    export_path : path-like, optional
        If given, the figure that owns ``ax3d`` is saved there.
    t : int, optional
        Frame to plot, by default 0.
    marker_types : dict, optional
        Marker for every key point with three or fewer points, by default
        markers for "R_head", "L_head" and "Neck".
    line_style : str, optional
        Line style for key points with more than three points.

    Raises
    ------
    KeyError
        If a key point drawn with markers has no entry in ``marker_types``.
    """
    if marker_types is None:
        marker_types = {
            "R_head": "o",
            "L_head": "o",
            "Neck": "x",
        }

    # One extra color per map because indexing starts at 1 (skips entry 0).
    n_right = sum(1 for kp in points3d if "R" in kp)
    n_left = sum(1 for kp in points3d if "L" in kp)
    color_map_right = generate_color_map(cmap="Reds", n=n_right + 1)
    color_map_left = generate_color_map(cmap="Blues", n=n_left + 1)

    right_idx, left_idx = 1, 1
    for kp, points3d_array in points3d.items():
        if "R" in kp:
            color = color_map_right[right_idx]
            right_idx += 1
        elif "L" in kp:
            color = color_map_left[left_idx]
            left_idx += 1
        else:
            color = "lightgrey"

        coords = (
            points3d_array[t, :, 0],
            points3d_array[t, :, 1],
            points3d_array[t, :, 2],
        )
        if points3d_array.shape[1] > 3:
            ax3d.plot(
                *coords,
                label=kp,
                linestyle=line_style,
                linewidth=1.7,
                color=color,
            )
        else:
            ax3d.plot(
                *coords,
                label=kp,
                marker=marker_types[kp],
                markersize=4.5,
                color=color,
            )

    if export_path is not None:
        # Save the figure this axes belongs to, not whichever figure happens
        # to be pyplot's current one.
        ax3d.figure.savefig(export_path, bbox_inches="tight")
def plot_trailing_kp(
    ax3d,
    points3d,
    segments_to_plot,
    export_path=None,
    t=0,
    trail=5,
    marker_type="x",
):
    """Plots the traces of key points from t-trail to t.

    Parameters
    ----------
    ax3d : 3D matplotlib axes
        Axes to draw on.
    points3d : dict
        Maps key point names to arrays indexed as ``[frame, point, axis]``.
    segments_to_plot : dict
        Maps a key point name (a key of ``points3d``) to the point index
        (or indices) whose trace is drawn.
    export_path : path-like, optional
        If given, the figure that owns ``ax3d`` is saved there.
    t : int, optional
        Current frame; the trace covers frames ``max(0, t - trail)`` up to
        but not including ``t``. By default 0.
    trail : int, optional
        Length of the trace in frames, by default 5.
    marker_type : str, optional
        Marker used for the scatter points, by default "x".
    """
    color_map_right = generate_color_map(cmap="Reds", n=len(points3d) + 1)
    color_map_left = generate_color_map(cmap="Blues", n=len(points3d) + 1)

    # The frame window is the same for every key point.
    window = slice(max(0, t - trail), t)

    right_idx, left_idx = 1, 1
    for kp, ind in segments_to_plot.items():
        if "R" in kp:
            color = color_map_right[right_idx]
            right_idx += 1
        elif "L" in kp:
            color = color_map_left[left_idx]
            left_idx += 1
        else:
            color = "grey"

        key_point = points3d[kp]
        ax3d.scatter(
            key_point[window, ind, 0],
            key_point[window, ind, 1],
            key_point[window, ind, 2],
            label=kp,
            marker=marker_type,
            color=color,
        )

    if export_path is not None:
        # Save the figure this axes belongs to, not whichever figure happens
        # to be pyplot's current one.
        ax3d.figure.savefig(export_path, bbox_inches="tight")
def plot_joint_angle(
    ax: plt.Axes,
    kinematics_data: Dict[str, np.ndarray],
    angles_to_plot: List,
    degrees: bool = True,
    until_t: int = -1,
    stim_lines: List[int] = None,
    show_legend: bool = True,
    export_path: Path = None,
):
    """Plot joint angles from a given kinematics data.

    Each angle is drawn as a line whose last sample is highlighted with a
    marker. The legend label is built from the last two ``_``-separated
    parts of the column name; antenna angles ("pitch R", "yaw L", ...) get
    an "ant. " prefix.

    Parameters
    ----------
    ax : plt.Axes
        Axis where the plot will be displayed.
    kinematics_data : Dict[str, np.ndarray]
        Dictionary containing the kinematics, pose and angle
    angles_to_plot : List
        Angles to plot. Exact column name should be given.
    degrees : bool, optional
        Multiply the values by ``180 / pi`` before plotting, by default True
    until_t : int, optional
        Plots the angles until t^th frame (exclusive). The default -1 plots
        the whole series.
    stim_lines : List[int], optional
        Plots vertical lines in indicated locations, by default None
    show_legend : bool, optional
        Shows legend, by default True
    export_path : Path, optional
        Path where the figure owning ``ax`` will be saved, by default None
    """
    antenna_labels = {
        "pitch R",
        "pitch L",
        "yaw R",
        "yaw L",
        "roll R",
        "roll L",
    }
    colors = generate_color_map(cmap="Set2", n=len(angles_to_plot))
    convert2deg = 180 / np.pi if degrees else 1
    # -1 is the "no limit" sentinel; slicing with it would drop the last sample.
    stop = None if until_t == -1 else until_t

    for color, joint_name in zip(colors, angles_to_plot):
        series = np.array(kinematics_data[joint_name][:stop]) * convert2deg

        label = " ".join(joint_name.split("_")[-2:])
        if label in antenna_labels:
            label = "ant. " + label

        ax.plot(
            series,
            "o-",
            ms=4.5,
            # Mark only the latest sample; an empty series has none to mark.
            markevery=[-1] if len(series) else None,
            label=label,
            color=color,
        )

    if stim_lines is not None:
        ax.vlines(stim_lines, -200, 200, "red", lw=0.5)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if show_legend:
        ax.legend(bbox_to_anchor=(1.2, 1), frameon=False, borderaxespad=0.0)

    if export_path is not None:
        # Save the figure this axes belongs to, not whichever figure happens
        # to be pyplot's current one.
        ax.figure.savefig(export_path, bbox_inches="tight")
def plot_grid(
    img_front: np.ndarray,
    img_side: np.ndarray,
    aligned_pose: Dict[str, np.ndarray],
    joint_angles: Dict[str, np.ndarray],
    leg_angles_to_plot: List[str],
    head_angles_to_plot: List[str],
    t: int,
    t_start: int,
    t_end: int,
    t_interval: int = 20,
    fps: int = 100,
    trail: int = 30,
    key_points_to_trail: Optional[Dict[str, Iterable]] = None,
    marker_types_3d: Optional[Dict[str, str]] = None,
    marker_trail: Optional[str] = None,
    stim_lines: Optional[List[int]] = None,
    export_path: Optional[Path] = None,
    **kwargs,
):
    """
    Plots an instance of the animal recording, 3D pose,
    head and leg joint angles in a grid layout.
    This code is intended to use in a for loop to plot and save
    all the frames from `t_start` to `t_end` to make a video afterwards.

    The figure holds six axes: side image, front image, 3D pose, and one
    time-series panel each for the head, the left front leg and the right
    front leg.

    NOTE: This function is intended to plot leg and head joint angles
    together, it will not work if any of these data is missing.

    Parameters
    ----------
    img_front : np.ndarray
        Image of the fly at frame t on the front camera.
    img_side : np.ndarray
        Image of the fly at frame t on the side camera.
    aligned_pose : Dict[str, np.ndarray]
        Aligned 3D pose.
    joint_angles : Dict[str, np.ndarray]
        Joint angles.
    leg_angles_to_plot : List[str]
        List containing leg joint angle names without the side; they are
        expanded to ``Angle_LF_<name>`` and ``Angle_RF_<name>``. Example::

            leg_joint_angles = [
                "ThC_yaw",
                "ThC_pitch",
                "ThC_roll",
                "CTr_pitch",
                "CTr_roll",
                "FTi_pitch",
                "TiTa_pitch",
            ]

    head_angles_to_plot : List[str]
        List containing exact names of head joint angle names.
    t : int
        Frame number t.
    t_start : int
        Start of the time series, i.e., joint angles.
    t_end : int
        End of the time series, i.e., joint angles.
    t_interval : int, optional
        Interval of frame numbers in between x ticks, by default 20
    fps : int, optional
        Frames per second, used to convert the x tick labels of the bottom
        panel from frames to seconds, by default 100
    trail : int, optional
        Number of previous frames where the key point will
        be visible, by default 30
    key_points_to_trail : Dict[str, Iterable], optional
        Key points whose trailing trace is drawn on the 3D axes; forwarded
        to ``plot_trailing_kp``. No trace is drawn when None (default).
    marker_types_3d : Dict[str, str], optional
        Dictionary mapping key points names to marker styles; forwarded to
        ``plot_3d_points``.
    marker_trail : str, optional
        Marker style used for the trailing key points; forwarded to
        ``plot_trailing_kp``.
    stim_lines : List[int], optional
        Stimulation indicators, by default None
    export_path : Path, optional
        Path where the plot will be saved, by default None
    **kwargs
        Recognized keys: ``plot_right_leg``, ``plot_left_leg``,
        ``plot_head`` (each defaults to True) and ``azim`` (azimuth of the
        3D view, defaults to 7). Other keys are ignored.

    Returns
    -------
    Fig
        Figure.
    """
    plot_right_leg = kwargs.pop("plot_right_leg", True)
    plot_left_leg = kwargs.pop("plot_left_leg", True)
    plot_head = kwargs.pop("plot_head", True)
    azim = kwargs.pop("azim", 7)

    assert (
        t_start <= t <= t_end
    ), "t_start should be smaller than t_end, t should be in between"
    # import pylustrator
    # pylustrator.start()

    # NOTE: this changes the pyplot style globally, not just for this figure.
    plt.style.use("dark_background")

    fig = plt.figure(figsize=(14, 6), dpi=120)

    gs = GridSpec(3, 4, figure=fig)
    # 7cam recording
    ax_img_side = fig.add_subplot(gs[0, :2])
    ax_img_front = fig.add_subplot(gs[1, :2])
    # 3D pose
    ax1 = fig.add_subplot(gs[2, :2], projection="3d")
    # joint angles: head (ax2), left leg (ax3), right leg (ax4)
    ax2 = fig.add_subplot(gs[0, 2:])
    ax3 = fig.add_subplot(gs[1, 2:])
    ax4 = fig.add_subplot(gs[2, 2:])

    # show the camera images
    # img = cv2.imread(str(img_path / f"frame_{t}.jpg"), 0)
    ax_img_side.imshow(img_side, vmin=0, vmax=255, cmap="gray")
    ax_img_front.imshow(img_front, vmin=0, vmax=255, cmap="gray")

    plot_3d_points(ax1, aligned_pose, marker_types=marker_types_3d, t=t)
    if key_points_to_trail is not None:
        plot_trailing_kp(
            ax1,
            aligned_pose,
            key_points_to_trail,
            marker_type=marker_trail,
            trail=trail,
            t=t,
        )

    if plot_head:
        plot_joint_angle(
            ax2,
            joint_angles,
            angles_to_plot=head_angles_to_plot,
            until_t=t,
            stim_lines=stim_lines,
        )

    if plot_left_leg:
        plot_joint_angle(
            ax3,
            joint_angles,
            angles_to_plot=[f"Angle_LF_{ja}" for ja in leg_angles_to_plot],
            until_t=t,
            stim_lines=stim_lines,
        )
    if plot_right_leg:
        # No legend here: the leg angle names are the same as in the left
        # leg panel.
        plot_joint_angle(
            ax4,
            joint_angles,
            angles_to_plot=[f"Angle_RF_{ja}" for ja in leg_angles_to_plot],
            until_t=t,
            show_legend=False,
            stim_lines=stim_lines,
        )

    ax_img_side.axis("off")
    ax_img_front.axis("off")

    # ax1 properties
    ax1.view_init(azim=azim, elev=10)
    ax1.set_xlim3d([-0.45, 0.45])
    ax1.set_ylim3d([-0.6, 0.6])
    ax1.set_zlim3d([0.42, 1.1])
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_zticks([])
    ax1.axis("off")

    # ax2 properties
    ax2.set_xlim((t_start, t_end))
    ax2.spines["bottom"].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.set_xlim((t_start, t_end))
    ax2.set_ylim((-90, 70))
    ax2.set_xticks([])
    ax2.set_yticks(ticks=[-90, 0, 70])
    ax2.set_yticklabels(labels=[-90, 0, 70])

    # ax3 properties
    ax3.set_xlim((t_start, t_end))
    ax3.spines["bottom"].set_visible(False)
    ax3.spines["top"].set_visible(False)
    ax3.spines["right"].set_visible(False)
    ax3.set_ylim((-160, 160))
    ax3.set_xticks([])
    ax3.set_yticks(ticks=[-160, 0, 160])
    ax3.set_yticklabels(labels=[-160, 0, 160])

    # ax4 properties (only this bottom panel shows the time axis, in seconds)
    ax4.set_xlim((t_start, t_end))
    ax4.set_ylim((-160, 160))
    ax4.spines["top"].set_visible(False)
    ax4.spines["right"].set_visible(False)
    ax4.set_yticks(ticks=[-160, 0, 160])
    ax4.set_yticklabels(labels=[-160, 0, 160])
    ax4.set_xticks(ticks=np.arange(t_start, t_end + t_interval, t_interval))
    ax4.set_xticklabels(
        labels=np.arange(t_start, t_end + t_interval, t_interval) / fps
    )
    ax4.set_xlabel("Time (s)")

    # #% start: automatic generated code from pylustrator
    # The layout below addresses axes and texts by index, so it relies on
    # the order in which they were created above.
    fig.set_size_inches(22.710000 / 2.54, 11.430000 / 2.54, forward=True)
    fig.text(
        0.3865,
        0.9184,
        "Head and antennae joint angles (deg)",
        transform=fig.transFigure,
    )
    fig.text(
        0.3865,
        0.6346,
        "Left front leg joint angles (deg)",
        transform=fig.transFigure,
    )  # id=fig.texts[0].new
    fig.text(
        0.3865,
        0.3502,
        "Right front leg joint angles (deg)",
        transform=fig.transFigure,
    )  # id=fig.texts[1].new
    # #% start: automatic generated code from pylustrator
    fig.set_size_inches(21.890000 / 2.54, 11.420000 / 2.54, forward=True)
    fig.axes[0].set_position([0.049668, 0.665633, 0.277935, 0.266469])
    fig.axes[1].set_position([0.049668, 0.382601, 0.277935, 0.266469])
    fig.axes[2].set(position=[0.1194, 0.06658, 0.1318, 0.2619])
    fig.axes[2].set_position([0.123600, 0.069459, 0.136149, 0.261065])
    fig.axes[3].set_position([0.400791, 0.677390, 0.383414, 0.225807])
    fig.axes[4].set_position([0.400791, 0.396851, 0.383414, 0.225807])
    fig.axes[5].set_position([0.400791, 0.112723, 0.383414, 0.225807])
    fig.texts[0].set_position([0.399756, 0.918650])
    fig.texts[1].set_position([0.399756, 0.635718])
    fig.texts[2].set_position([0.399756, 0.352188])
    fig.axes[3].legend(loc=(1.068, -0.1324), frameon=False)
    fig.axes[4].legend(loc=(1.064, -0.569), frameon=False)
    # % end: automatic generated code from pylustrator

    if export_path is not None:
        fig.savefig(export_path, bbox_inches="tight")
        print(f"Figure saved at {str(export_path)}")

    return fig
def plot_grid_generator(
    fly_frames_front: np.ndarray,
    fly_frames_side: np.ndarray,
    aligned_pose: Dict[str, np.ndarray],
    joint_angles: Dict[str, np.ndarray],
    leg_angles_to_plot: List[str],
    head_angles_to_plot: List[str],
    t_start: int,
    t_end: int,
    t_interval: int = 20,
    fps: int = 100,
    trail: int = 30,
    key_points_to_trail: Optional[Dict[str, Iterable]] = None,
    marker_types_3d: Optional[Dict[str, str]] = None,
    marker_trail: Optional[str] = "x",
    stim_lines: Optional[List[int]] = None,
    export_path: Optional[Path] = None,
    **kwargs,
):
    """Generator for plotting grid.

    Pairs up the front and side camera frames, renders one ``plot_grid``
    figure per pair and yields it as an image array. The k-th pair is
    plotted as frame ``t_start + k``; iteration ends with the shorter of the
    two frame sources.

    Parameters
    ----------
    fly_frames_front, fly_frames_side
        Iterables of front / side camera images.
    key_points_to_trail : Dict[str, Iterable], optional
        Key points whose trailing trace is drawn; forwarded to
        ``plot_grid``.
    export_path : Path, optional
        Forwarded unchanged to ``plot_grid`` for every frame, so each frame
        is saved to the same path.
    **kwargs
        Forwarded to ``plot_grid``.

    All remaining parameters are forwarded to ``plot_grid`` unchanged.

    Yields
    ------
    np.ndarray
        The rendered grid for one frame, as returned by ``fig_to_array``.
    """
    frame_pairs = zip(fly_frames_front, fly_frames_side)
    for frame_index, (fly_img_front, fly_img_side) in enumerate(frame_pairs):
        fig = plot_grid(
            img_front=fly_img_front,
            img_side=fly_img_side,
            aligned_pose=aligned_pose,
            joint_angles=joint_angles,
            leg_angles_to_plot=leg_angles_to_plot,
            head_angles_to_plot=head_angles_to_plot,
            marker_types_3d=marker_types_3d,
            marker_trail=marker_trail,
            key_points_to_trail=key_points_to_trail,
            t=frame_index + t_start,
            t_start=t_start,
            t_end=t_end,
            t_interval=t_interval,
            fps=fps,
            trail=trail,
            stim_lines=stim_lines,
            export_path=export_path,
            **kwargs,
        )
        # Drop the figure from pyplot's registry so figures do not pile up;
        # the Figure object itself is still rendered below.
        plt.close(fig)
        yield fig_to_array(fig)
def fig_to_array(fig):
    """Converts a matplotlib figure into an array.

    The figure is attached to a fresh Agg canvas, drawn, and its pixel
    buffer is returned without the alpha channel.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to rasterize.

    Returns
    -------
    np.ndarray
        ``uint8`` image with three color channels in the last axis.
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    # ``tostring_rgb`` is deprecated in recent matplotlib releases; read the
    # RGBA buffer instead and discard alpha.
    rgba = np.asarray(canvas.buffer_rgba())
    # Copy into a contiguous array so the result does not alias the canvas.
    return np.ascontiguousarray(rgba[..., :3])