Replaying experimental recordings and inferring dynamical quantities
In this tutorial, we replay a snippet of experimentally recorded fly walking by driving position actuators with the recorded joint angle sequences. In other words, at every step the simulation compares the current joint angles with the target (i.e., recorded) joint angles and applies a corrective force proportional to the difference.
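Conceptually, a position actuator behaves like a proportional controller. The following standalone sketch (not FlyGym internals; the function name and numbers are hypothetical) illustrates the idea:

```python
def position_actuator_torque(current_angle, target_angle, kp):
    """Torque proportional to the angular discrepancy (a pure P-controller;
    real physics engines typically add damping/velocity terms on top)."""
    return kp * (target_angle - current_angle)


# Example: the actuator pulls the joint toward the recorded target angle
torque = position_actuator_torque(current_angle=0.25, target_angle=0.75, kp=150.0)
print(torque)  # 75.0 (uN*mm if kp is in uN*mm/rad)
```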
We start by loading a snippet of experimentally recorded kinematics. A MotionSnippet helper class is implemented for this purpose.
import numpy as np
from flygym_demo.spotlight_data import MotionSnippet
snippet = MotionSnippet()
This recording was collected using the Spotlight system, and the 3D pose was inferred using PoseForge. This workflow is described in Wang-Chen et al., 2026.
Let's inspect what this dataset contains:
print("Experimental data and metadata:")
print(f" legs: {snippet.legs}")
print(f" dofs_per_leg: {snippet.dofs_per_leg}")
print(f" data_fps: {snippet.data_fps}")
print(f" experiment_trial: {snippet.experiment_trial}")
print(f" framerange_in_raw_recording: {snippet.framerange_in_raw_recording}")
print(f" joint_angles: shape={snippet.joint_angles.shape}, dtype={snippet.joint_angles.dtype}")
print(f" fwdkin_egoxyz: shape={snippet.fwdkin_egoxyz.shape}, dtype={snippet.fwdkin_egoxyz.dtype}")
print(f" rawpred_egoxyz: shape={snippet.rawpred_egoxyz.shape}, dtype={snippet.rawpred_egoxyz.dtype}")
Experimental data and metadata:
legs: ['lf', 'lm', 'lh', 'rf', 'rm', 'rh']
dofs_per_leg: [('thorax', 'coxa', 'pitch'), ('thorax', 'coxa', 'roll'), ('thorax', 'coxa', 'yaw'), ('coxa', 'trochanterfemur', 'pitch'), ('coxa', 'trochanterfemur', 'roll'), ('trochanterfemur', 'tibia', 'pitch'), ('tibia', 'tarsus1', 'pitch')]
data_fps: 330
experiment_trial: 20250613-fly1b-012
framerange_in_raw_recording: [1033, 1693]
joint_angles: shape=(660, 6, 7), dtype=float32
fwdkin_egoxyz: shape=(660, 30, 3), dtype=float32
rawpred_egoxyz: shape=(660, 30, 3), dtype=float32
As a sanity check, we can plot the egocentric x (fore/aft), y (left/right), and z (height) positions of the claw of the left front leg over time.
It is worth noting that the dataset contains two versions of the xyz keypoint positions:
- The raw coordinates predicted by the PoseForge model (rawpred_egoxyz). These are outputs of a black-box deep learning model and do not explicitly take anatomical constraints into account. We will plot these as black dotted lines.
- The anatomically constrained coordinates following inverse and forward kinematics (fwdkin_egoxyz). That is, joint angles in a purely geometric model (no physics simulation) are fitted to minimize the discrepancy between the geometric model's keypoint positions and the raw model outputs. As such, these fitted keypoint positions are constrained by (a) leg segment lengths and (b) the available joint rotation axes (e.g., only two axes of rotation are possible at the femur-tibia joint, and only one is available at the tibia-tarsus joint). The process of fitting joint angles is called inverse kinematics, and the process of computing the resulting constrained xyz coordinates is called forward kinematics. This dataset uses SeqIKPy (Özdil et al., 2026) for inverse and forward kinematics, but many other options are available (e.g., DeepLabCut + Anipose, and DeepFly3D). We will plot these in blue.
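To make the inverse/forward kinematics round trip concrete, here is a minimal standalone sketch for a planar two-link "leg" (hypothetical unit segment lengths; SeqIKPy's actual method operates on the full 3D leg chain):

```python
import numpy as np

L1, L2 = 1.0, 1.0  # hypothetical segment lengths (mm)


def forward_kinematics(theta1, theta2):
    """Claw position of a planar 2-link leg given joint angles."""
    x = L1 * np.cos(theta1) + L2 * np.cos(theta1 + theta2)
    y = L1 * np.sin(theta1) + L2 * np.sin(theta1 + theta2)
    return np.array([x, y])


def inverse_kinematics(x, y):
    """Joint angles placing the claw at (x, y), via the law of cosines.
    np.clip keeps unreachable targets on the boundary of the workspace."""
    c2 = np.clip((x**2 + y**2 - L1**2 - L2**2) / (2 * L1 * L2), -1, 1)
    theta2 = np.arccos(c2)
    theta1 = np.arctan2(y, x) - np.arctan2(
        L2 * np.sin(theta2), L1 + L2 * np.cos(theta2)
    )
    return theta1, theta2


# Round trip: target keypoint -> joint angles -> constrained keypoint
theta1, theta2 = inverse_kinematics(1.2, 0.9)
print(forward_kinematics(theta1, theta2))  # ≈ [1.2, 0.9] (target is reachable)
```

In the dataset, the same round trip is what turns raw PoseForge keypoints into the anatomically constrained fwdkin_egoxyz coordinates, except that the fit is done jointly over the whole leg chain and all frames.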
import matplotlib.pyplot as plt
leg_to_plot = "lf"
leg_to_plot_idx = snippet.legs.index(leg_to_plot)
n_kpts_per_leg = 5
kpt_idx_to_visualize = leg_to_plot_idx * n_kpts_per_leg + n_kpts_per_leg - 1
expdata_timegrid = np.arange(snippet.joint_angles.shape[0]) / snippet.data_fps
fig, axes = plt.subplots(3, 1, figsize=(8, 3.5), tight_layout=True, sharex=True)
for dim_idx, dim_name in enumerate("xyz"):
    ax = axes[dim_idx]
    ts_fwdkin = snippet.fwdkin_egoxyz[:, kpt_idx_to_visualize, dim_idx]
    ts_rawpred = snippet.rawpred_egoxyz[:, kpt_idx_to_visualize, dim_idx]
    ax.plot(
        expdata_timegrid,
        ts_fwdkin,
        label="Forward kinematics",
        color="C0",
    )
    ax.plot(
        expdata_timegrid,
        ts_rawpred,
        label="Raw prediction",
        linestyle=":",
        color="black",
    )
    # ignorable plotting code below
    _center = np.percentile(ts_rawpred, [10, 90]).mean()
    _plotspan = 2  # mm
    ax.set_ylim(_center - 0.5 * _plotspan, _center + 0.5 * _plotspan)
    ax.set_ylabel("Distance\n(mm)")
    ax.set_title(dim_name, fontsize="medium")
    if dim_name == "z":
        ax.set_xlabel("Time (s)")
    if dim_name == "x":
        ax.legend(bbox_to_anchor=(1.04, 0.5), loc="center left")
fig.suptitle(f"{leg_to_plot.upper()} leg claw position (egocentric Cartesian frame)")
Text(0.5, 0.98, 'LF leg claw position (egocentric Cartesian frame)')
We can also plot all legs in 3D, but only for one frame:
frame_to_plot = 0
ax = plt.figure(figsize=(5, 3), tight_layout=True).add_subplot(projection="3d")
for leg_idx, leg_name in enumerate(snippet.legs):
    kpt_indices = slice(leg_idx * n_kpts_per_leg, (leg_idx + 1) * n_kpts_per_leg)
    fwdkin_egoxyz = snippet.fwdkin_egoxyz[frame_to_plot, kpt_indices, :]
    rawpred_egoxyz = snippet.rawpred_egoxyz[frame_to_plot, kpt_indices, :]
    ax.plot(
        fwdkin_egoxyz[:, 0],
        fwdkin_egoxyz[:, 1],
        fwdkin_egoxyz[:, 2],
        label=f"{leg_name.upper()} leg",
        color=f"C{leg_idx}",
    )
    ax.plot(
        rawpred_egoxyz[:, 0],
        rawpred_egoxyz[:, 1],
        rawpred_egoxyz[:, 2],
        label="Raw prediction" if leg_idx == 5 else None,
        linestyle=":",
        color="black",
    )
ax.set_xlabel("x (mm)")
ax.set_ylabel("y (mm)")
ax.set_zlabel("z (mm)")
ax.set_aspect("equal")
ax.legend(loc="center left", bbox_to_anchor=(1.3, 0.5))
<matplotlib.legend.Legend at 0x712eb43bb890>
Now, let's visualize the joint angle time series fitted through inverse kinematics:
fig, axes = plt.subplots(7, 1, figsize=(7, 6.5), tight_layout=True, sharex=True)
for dof_idx, (parent_link, child_link, axis) in enumerate(snippet.dofs_per_leg):
    ax = axes[dof_idx]
    ts_angles = np.rad2deg(snippet.joint_angles[:, leg_to_plot_idx, dof_idx])
    ax.plot(expdata_timegrid, ts_angles)
    _center = np.percentile(ts_angles, [10, 90]).mean()
    _plotspan = 120  # deg
    ax.set_ylim(_center - 0.5 * _plotspan, _center + 0.5 * _plotspan)
    ax.set_ylabel("Angle\n(°)")
    ax.set_title(f"{parent_link}-{child_link} {axis}", fontsize="medium")
    if dof_idx == 6:
        ax.set_xlabel("Time (s)")
fig.suptitle(f"{leg_to_plot.upper()} leg joint angles")
Text(0.5, 0.98, 'LF leg joint angles')
We are now ready to instantiate a fly model, attach it to a world, and create a simulation object. Detailed explanations are available in Tutorial 1: Composing models and scenes.
from flygym.compose import Fly, KinematicPosePreset, ActuatorType
from flygym.anatomy import Skeleton, AxisOrder, JointPreset, ActuatedDOFPreset
axis_order = AxisOrder.YAW_PITCH_ROLL
articulated_joints = JointPreset.LEGS_ONLY
actuated_dofs = ActuatedDOFPreset.LEGS_ACTIVE_ONLY
neutral_pose = KinematicPosePreset.NEUTRAL
actuator_type = ActuatorType.POSITION
# This controls how strongly the actuators try to track the target joint angles
actuator_gain = 150.0 # in uN*mm/rad (torque applied per angular discrepancy)
fly = Fly()
skeleton = Skeleton(axis_order=axis_order, joint_preset=articulated_joints)
fly.add_joints(skeleton, neutral_pose=neutral_pose)
actuated_dofs_list = fly.skeleton.get_actuated_dofs_from_preset(actuated_dofs)
fly.add_actuators(
    actuated_dofs_list,
    actuator_type=actuator_type,
    kp=actuator_gain,
    neutral_input=neutral_pose,
)
fly.colorize()
tracking_cam = fly.add_tracking_camera()
The only thing worth noting here that was not mentioned in the previous tutorials is that insects have specialized structures on their legs that allow them to adhere to surfaces. We can use fly.add_leg_adhesion to emulate this. However, note that adding adhesion to the model does not do anything by itself—the adhesion actuators must be switched on (or off) during simulation time, as demonstrated in a later code block.
fly.add_leg_adhesion()
{'lf': MJCF Element: <adhesion name="lf_tarsus5-adhesion" class="/" ctrlrange="1 100" body="lf_tarsus5" gain="1"/>,
'lm': MJCF Element: <adhesion name="lm_tarsus5-adhesion" class="/" ctrlrange="1 100" body="lm_tarsus5" gain="1"/>,
'lh': MJCF Element: <adhesion name="lh_tarsus5-adhesion" class="/" ctrlrange="1 100" body="lh_tarsus5" gain="1"/>,
'rf': MJCF Element: <adhesion name="rf_tarsus5-adhesion" class="/" ctrlrange="1 100" body="rf_tarsus5" gain="1"/>,
'rm': MJCF Element: <adhesion name="rm_tarsus5-adhesion" class="/" ctrlrange="1 100" body="rm_tarsus5" gain="1"/>,
'rh': MJCF Element: <adhesion name="rh_tarsus5-adhesion" class="/" ctrlrange="1 100" body="rh_tarsus5" gain="1"/>}
Now we create a world, attach the fly to it, set up the simulation, and add a renderer:
from flygym.anatomy import JointDOF, RotationAxis, BodySegment
from flygym.compose import FlatGroundWorld
from flygym.utils.math import Rotation3D
from flygym import Simulation
spawn_pos = [0, 0, 0.7] # center of thorax is at 0.7 mm above the ground
spawn_rot = Rotation3D(format="quat", values=[1, 0, 0, 0]) # no rotation
world = FlatGroundWorld()
world.add_fly(fly, spawn_pos, spawn_rot)
sim = Simulation(world)
sim.set_renderer(tracking_cam)
<flygym.rendering.Renderer at 0x712eb4246ff0>
Here we call MotionSnippet.get_joint_angles to produce the actuator target sequence. Internally this does three things:
- The Spotlight system records data at 330 Hz, which is much lower than the typical NeuroMechFly simulation frequency (10 kHz). The recorded joint angles are therefore upsampled to the target frequency. More precisely, a 3rd-order Savitzky-Golay filter with a very small time window (0.03 s) is first applied to the time series; the smoothed time series is then resampled onto the finer simulation time grid using cubic-spline interpolation.
- The data is reshaped into a NumPy array of shape (n_steps, n_actuated_dofs), with the DoFs reordered to match the order expected by FlyGym.
- FlyGym (after v2.0.0) defines joint angles anatomically—that is, "outward" rotations have the same sign on both left- and right-side legs even though they have opposite directions of rotation relative to the world frame. This is not the case in SeqIKPy, so this function flips the sign for roll and yaw DoFs on the right-side legs (pitch is not affected).
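The resampling step above can be sketched as follows. This is a standalone illustration on a synthetic trace; the exact filter window length used by MotionSnippet.get_joint_angles is an assumption here:

```python
import numpy as np
from scipy.signal import savgol_filter
from scipy.interpolate import CubicSpline

data_fps, sim_timestep = 330, 1e-4
t_raw = np.arange(660) / data_fps           # recorded time grid (330 Hz)
angles_raw = np.sin(2 * np.pi * 5 * t_raw)  # stand-in joint-angle trace

# 1. Smooth with a 3rd-order Savitzky-Golay filter over a ~0.03 s window
#    (window length in samples is an assumption; it must be odd)
window = int(0.03 * data_fps) | 1
angles_smooth = savgol_filter(angles_raw, window, polyorder=3)

# 2. Upsample to the 10 kHz simulation grid via cubic-spline interpolation
t_sim = np.arange(0, t_raw[-1], sim_timestep)
angles_sim = CubicSpline(t_raw, angles_smooth)(t_sim)
print(angles_sim.shape)  # (19970,)
```

The real helper additionally reorders the DoF axis to match FlyGym's expected actuator order and flips the sign of roll/yaw DoFs on right-side legs, as described above.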
sim_timestep = 1e-4
joint_angles_nmf = snippet.get_joint_angles(
    output_timestep=sim_timestep,
    output_dof_order=fly.get_actuated_jointdofs_order(actuator_type),
)
sim_duration_sec = snippet.joint_angles.shape[0] / snippet.data_fps
nmfsim_timegrid = np.arange(0, sim_duration_sec, sim_timestep)
print("Number of steps to simulate:", nmfsim_timegrid.size)
print("Shape of target joint angles:", joint_angles_nmf.shape)
Number of steps to simulate: 20000
Shape of target joint angles: (20000, 42)
Now we implement the main simulation loop. At each timestep we feed the precomputed target angles to the position actuators, step the physics forward, and record the resulting joint angles and actuator torques. We also call the renderer at each step; the FlyGym renderer internally determines whether a frame should actually be rendered based on the user-specified output playback speed and output frame rate (in this case we use the default parameters when we called sim.set_renderer).
Note that leg adhesion is turned on before the loop and kept on throughout.
from tqdm import trange
fly_name = fly.name
nsteps_sim = nmfsim_timegrid.size
n_dofs = len(fly.get_jointdofs_order())
n_actuated_dofs = len(fly.get_actuated_jointdofs_order(actuator_type))
simulated_joint_angles = np.full((nsteps_sim, n_dofs), np.nan, dtype=np.float32)
actuator_torques = np.full((nsteps_sim, n_actuated_dofs), np.nan, dtype=np.float32)
sim.reset()
# Turn adhesion on for all 6 legs
sim.set_leg_adhesion_states(fly_name, np.ones(6, dtype=bool))
# sim.warmup()
for step_idx in trange(nsteps_sim, desc="Simulating"):
    # Set actuator inputs
    target_angles = joint_angles_nmf[step_idx, :]
    sim.set_actuator_inputs(fly_name, actuator_type, target_angles)
    # Step simulation
    sim.step_with_profile()  # can be replaced with sim.step()
    # Record simulation data
    simulated_joint_angles[step_idx, :] = sim.get_joint_angles(fly_name)
    actuator_torques[step_idx, :] = sim.get_actuator_forces(fly_name, actuator_type)
    # Render as needed (flygym internally decides whether to actually render a
    # frame based on renderer configs)
    sim.render_as_needed_with_profile()  # can be replaced with sim.render_as_needed()
Simulating: 100%|██████████| 20000/20000 [00:01<00:00, 13032.12it/s]
Observe that we have called sim.step_with_profile() and sim.render_as_needed_with_profile() to time code execution behind the scenes. This allows us to use sim.print_performance_report() to see the breakdown of compute time shown below. In practice, you can replace these with sim.step() and sim.render_as_needed(), which should in fact be slightly faster because they remove the timing overhead.
sim.print_performance_report()
PERFORMANCE PROFILE
| Stage | Time/step (us) | Percent (%) | Throughput (iters/s) | Throughput x realtime |
|---|---|---|---|---|
| Physics simulation advancement | 61 | 89 | 16446 | 1.64 |
| Rendering* | 7 | 11 | 133854 | 13.39 |
| TOTAL | 68 | 100 | 14647 | 1.46 |
* Note: 248 frames were rendered out of 20000 steps. Therefore, rendering time per image is 602 us.
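As a sanity check on the report, the "Throughput x realtime" column is simply the simulated time per step divided by the wall-clock time per step (numbers below taken from the physics row of the report above):

```python
sim_timestep = 1e-4          # simulated seconds per step
wall_time_per_step = 61e-6   # wall-clock seconds per physics step (from report)

# > 1 means the simulation runs faster than real time
realtime_factor = sim_timestep / wall_time_per_step
print(round(realtime_factor, 2))  # 1.64
```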
Let's play the rendered video in the notebook:
sim.renderer.show_in_notebook()
Finally, we compare the actuator target angles (blue) against the angles actually achieved by the simulator (orange):
fig, axes = plt.subplots(7, 1, figsize=(9, 6.5), tight_layout=True, sharex=True)
for dof_idx, (parent_link, child_link, axis) in enumerate(snippet.dofs_per_leg):
    ax = axes[dof_idx]
    if parent_link == "thorax":
        parent_name = "c_thorax"
    else:
        parent_name = f"{leg_to_plot}_{parent_link}"
    child_name = f"{leg_to_plot}_{child_link}"
    target_dof = JointDOF(BodySegment(parent_name), BodySegment(child_name), RotationAxis(axis))
    # Index into joint_angles_nmf (actuated DOF order)
    actuated_dof_index = fly.get_actuated_jointdofs_order(actuator_type).index(target_dof)
    ts_target = np.rad2deg(joint_angles_nmf[:, actuated_dof_index])
    ax.plot(nmfsim_timegrid, ts_target, label="Target", linestyle=":", color="C0")
    # Index into simulated_joint_angles (all-DOF order)
    nmf_dof_index = fly.get_jointdofs_order().index(target_dof)
    ts_sim = np.rad2deg(simulated_joint_angles[:, nmf_dof_index])
    ax.plot(nmfsim_timegrid, ts_sim, label="Simulated", color="C1")
    _center = np.percentile(ts_target, [10, 90]).mean()
    _plotspan = 120  # deg
    ax.set_ylim(_center - 0.5 * _plotspan, _center + 0.5 * _plotspan)
    ax.set_ylabel("Angle\n(°)")
    ax.set_title(f"{parent_link}-{child_link} {axis}", fontsize="medium")
    if dof_idx == 6:
        ax.set_xlabel("Time (s)")
    if dof_idx == 0:
        ax.legend(bbox_to_anchor=(1.04, 0.5), loc="center left")