Skip to content

rendering

WarpCPURenderer

Bases: _BaseWarpRenderer

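CPU-side renderer for multi-world MJWarp simulation. For each selected world, the state is copied from MJWarp into a CPU MjData and rendered with the standard MuJoCo renderer.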

Source code in src/flygym/warp/rendering.py
class WarpCPURenderer(_BaseWarpRenderer):
    """CPU-side renderer for multi-world MJWarp simulation.

    For every selected world, the simulation state is copied from the MJWarp
    ``Data`` into a single reusable CPU ``mj.MjData`` buffer, which is then
    rendered camera by camera with the regular MuJoCo renderer inherited from
    the base class.
    """

    def _render_setup_impl(self, **kwargs: Any) -> None:
        # One CPU MjData is reused for every world; it is reset and refilled
        # before each world is rendered in ``_render_impl``.
        self._mj_data_buffer = mj.MjData(self.mj_model)
        # Nothing else to do - just use mjRenderer inherited from CPU Renderer

    def _render_impl(self, mjw_data: mjw.Data) -> np.ndarray:
        """Render all selected worlds from all enabled cameras.

        Args:
            mjw_data: MJWarp simulation data from which each world's state is
                copied via ``mjw.get_data_into``.

        Returns:
            A ``uint8`` array of shape ``(*self._buf_dim_per_frame, 3)``. Entry
            ``[i, j]`` holds the frame of the ``i``-th world in
            ``self.world_ids`` seen from the ``j``-th camera in
            ``self.enabled_cam_names``. If ``self.buffer_frames`` is False the
            cameras are still rendered, but the returned array stays all zeros.
        """
        rendered_images = np.zeros((*self._buf_dim_per_frame, 3), dtype=np.uint8)

        # Resolve (output slot, MuJoCo camera ID) once, in output order. Using
        # enumerate gives each frame's slot directly (no O(n) ``list.index``
        # per iteration) and guarantees only enabled cameras are visited.
        enabled_cams = [
            (cid_among_rendered, self._cameras_names2id[cam_name])
            for cid_among_rendered, cam_name in enumerate(self.enabled_cam_names)
        ]

        for wid_among_rendered, world_id in enumerate(self.world_ids):
            # Copy this world's state into the CPU MjData struct
            mj.mj_resetData(self.mj_model, self._mj_data_buffer)
            mjw.get_data_into(self._mj_data_buffer, self.mj_model, mjw_data, world_id)

            # Render each enabled camera and store frames
            for cid_among_rendered, internal_cam_id in enabled_cams:
                self.mj_renderer.update_scene(
                    self._mj_data_buffer, internal_cam_id, self.scene_option
                )
                frame = self.mj_renderer.render()

                if self.buffer_frames:
                    rendered_images[wid_among_rendered, cid_among_rendered] = frame

        return rendered_images

    def _fetch_frames_to_cpu_impl(
        self, world_id_among_rendered: int, cam_id_among_rendered: int
    ) -> list[np.ndarray]:
        """Collect the buffered frames of one world/camera pair.

        Args:
            world_id_among_rendered: Position of the world in ``self.world_ids``.
            cam_id_among_rendered: Position of the camera in
                ``self.enabled_cam_names``.

        Returns:
            One frame per entry of ``self._frames``, in the same order.
        """
        return [
            rendered_images[world_id_among_rendered, cam_id_among_rendered, ...]
            for rendered_images in self._frames
        ]

WarpGPUBatchRenderer

Bases: _BaseWarpRenderer

GPU-side renderer using MJWarp's GPU batch rendering functionality.

Source code in src/flygym/warp/rendering.py
class WarpGPUBatchRenderer(_BaseWarpRenderer):
    """GPU-side renderer using MJWarp's GPU batch rendering functionality.

    All worlds are rendered on the GPU through an MJWarp render context; the
    regular CPU ``mj.Renderer`` inherited from the base class is discarded.
    """

    def _render_setup_impl(self, **kwargs: Any) -> None:
        """Create the MJWarp batch rendering context.

        Args:
            **kwargs: Forwarded to ``mjw.create_render_context``.

        Raises:
            RuntimeError: If ``self.scene_option`` differs from the MuJoCo
                default, since MJWarp batch rendering does not implement it.
        """
        if not self._is_scene_option_default(self.scene_option):
            raise RuntimeError(
                "Custom scene options are not supported with WarpGPUBatchRenderer "
                "because it is not implemented in MJWarp batch rendering."
            )

        self._world_ids_gpu = wp.array(self.world_ids, dtype=wp.int32)
        self._enabled_cam_ids_gpu = wp.array(
            [self._cameras_names2id[n] for n in self.enabled_cam_names], dtype=wp.int32
        )
        # One activation flag per model camera, indexed by MuJoCo camera ID.
        # A set makes each membership test O(1).
        enabled_cam_names = set(self.enabled_cam_names)
        cam_mask = [
            self._cameras_id2name[cid] in enabled_cam_names
            for cid in range(self.mj_model.ncam)
        ]

        # Create batch rendering context
        self._rendering_context = mjw.create_render_context(
            mjm=self.mj_model,
            nworld=self._n_worlds_total,
            cam_active=cam_mask,
            cam_res=self.camera_res[::-1],  # MJWarp expects (W, H); we use (H, W)
            **kwargs,
        )

        # Remove normal MjRenderer inherited from CPU Renderer
        self.scene_option = None
        self.mj_renderer = None

    def _render_impl(self, mjw_data: mjw.Data) -> wp.array:
        """Render all worlds on the GPU and gather the selected frames.

        Args:
            mjw_data: MJWarp simulation data for all worlds.

        Returns:
            A new ``wp.vec3f`` array of shape ``self._buf_dim_per_frame`` filled
            by ``get_rgb_selected_worlds_and_cameras`` for the worlds in
            ``self.world_ids`` and cameras in ``self.enabled_cam_names``.
            Values are presumably in [0, 1], given the 255 scaling applied in
            ``_fetch_frames_to_cpu_impl``.
        """
        # The BVH must be refit to the current poses before ray-cast rendering
        mjw.refit_bvh(self.mjw_model, mjw_data, self._rendering_context)
        mjw.render(self.mjw_model, mjw_data, self._rendering_context)
        rgb_out = wp.zeros(self._buf_dim_per_frame, dtype=wp.vec3f)
        get_rgb_selected_worlds_and_cameras(
            self._rendering_context,
            self._world_ids_gpu,
            self._enabled_cam_ids_gpu,
            rgb_out,
        )
        return rgb_out

    def _fetch_frames_to_cpu_impl(
        self, world_id_among_rendered: int, cam_id_among_rendered: int
    ) -> list[np.ndarray]:
        """Copy the buffered frames of one world/camera pair to host memory.

        Args:
            world_id_among_rendered: Position of the world in ``self.world_ids``.
            cam_id_among_rendered: Position of the camera in
                ``self.enabled_cam_names``.

        Returns:
            One ``uint8`` RGB frame per entry of ``self._frames``, in order.
        """
        frames = []
        for frame_buffer in self._frames:
            # Slice on the warp side first so only this frame is transferred
            frame = frame_buffer[world_id_among_rendered, cam_id_among_rendered, :, :]
            # Scale on the host with NumPy, and clip before the uint8 cast:
            # out-of-range float->uint8 casts are undefined in NumPy and would
            # wrap over-exposed pixels instead of saturating them.
            rgb = frame.numpy() * 255.0
            frames.append(np.clip(rgb, 0.0, 255.0).astype(np.uint8))
        return frames

    @override
    def close(self):
        return  # nothing to do since we are not using a mj.Renderer context

    @staticmethod
    def _is_scene_option_default(scene_option: mj.MjvOption) -> bool:
        """Return True if ``scene_option`` equals a freshly defaulted MjvOption."""
        default_option = mj.MjvOption()
        mj.mjv_defaultOption(default_option)
        return scene_option == default_option

modify_world_for_batch_rendering(world)

Modify world MJCF model to make it compatible with MJWarp's GPU batch rendering.

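This may reduce texture and lighting realism: textures on fly-body materials are replaced by their primary color, and an overhead light is added above each fly.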

Modification happens in place. Returns True if any modifications were made, False otherwise.

Note for developers: Check if anything here can be dropped upon new MJWarp releases.

Source code in src/flygym/warp/rendering.py
def modify_world_for_batch_rendering(world: BaseWorld) -> bool:
    """Modify world MJCF model to make it compatible with MJWarp's GPU batch rendering.

    This may reduce texture and lighting realism:

    - Textures on fly-body materials are replaced by their primary color.
    - ``texrepeat`` of every material is scaled down by 1000x.
    - An overhead light is added above each ``c_thorax`` body.

    Modification happens in place. Returns True if any modifications were made, False
    otherwise.

    Note: the ``texrepeat`` rescaling is not idempotent. Calling this function twice
    on the same world shrinks ``texrepeat`` twice.

    Note for developers: Check if anything here can be dropped upon new MJWarp releases.
    """
    is_modified = False

    # Strip textures from fly body materials
    # (rendering textures on complex meshes causes MJWarp memory corruption)
    for material in world.mjcf_root.asset.find_all("material"):
        if material.full_identifier.split("/")[0] not in world.fly_lookup:
            continue  # not a fly body material - leave it alone
        if material.texture is None:
            continue  # material doesn't have texture - nothing to strip
        texture_element = world.mjcf_root.asset.find(
            "texture", material.texture.full_identifier
        )
        # Unset MJCF attributes read back as None. Fall back to MuJoCo's defaults
        # (texture rgb1 = 0.8 0.8 0.8, material rgba = 1 1 1 1) instead of
        # failing on None.
        primary_color_rgb = texture_element.rgb1
        if primary_color_rgb is None:
            primary_color_rgb = (0.8, 0.8, 0.8)
        rgba = (
            list(material.rgba) if material.rgba is not None else [1.0, 1.0, 1.0, 1.0]
        )
        rgba[:3] = [float(c) for c in primary_color_rgb]
        material.texture = None
        material.rgba = rgba
        is_modified = True

    # Adjust scale of checker materials (e.g., ground): texrepeat needs to be scaled
    # down by 1000x to get the same pattern - unclear why
    for material in world.mjcf_root.asset.find_all("material"):
        if material.texrepeat is not None:
            material.texrepeat = tuple(tr / 1000 for tr in material.texrepeat)
            is_modified = True

    # Add light above each fly explicitly
    for body in world.mjcf_root.find_all("body"):
        if hasattr(body, "name") and body.name == "c_thorax":
            warnings.warn(f"Adding overhead light for body {body.full_identifier}")
            body.add(
                "light",
                name=body.full_identifier.replace("/", "-") + "-overheadlight",
                mode="track",
                # NOTE(review): per the MuJoCo XML reference, `target` is only
                # consulted in targetbody/targetbodycom modes; with mode="track"
                # it appears to be ignored. Kept as-is to avoid changing the MJCF.
                target="c_thorax",
                pos=(0, 0, 30),
                dir=(0, 0, -1),
                directional=True,
                ambient=(10, 10, 10),
                diffuse=(10, 10, 10),
                specular=(0.3, 0.3, 0.3),
            )
            is_modified = True

    return is_modified