Source code for mujoco_warp._src.render

# Copyright 2026 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Tuple

import warp as wp

from mujoco_warp._src import math
from mujoco_warp._src.ray import ray_box
from mujoco_warp._src.ray import ray_capsule
from mujoco_warp._src.ray import ray_cylinder
from mujoco_warp._src.ray import ray_ellipsoid
from mujoco_warp._src.ray import ray_flex_with_bvh
from mujoco_warp._src.ray import ray_mesh_with_bvh
from mujoco_warp._src.ray import ray_mesh_with_bvh_anyhit
from mujoco_warp._src.ray import ray_plane
from mujoco_warp._src.ray import ray_sphere
from mujoco_warp._src.render_util import compute_ray
from mujoco_warp._src.render_util import pack_rgba_to_uint32
from mujoco_warp._src.types import MJ_MAXVAL
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import RenderContext
from mujoco_warp._src.warp_util import event_scope

# Set Warp's `enable_backward` module option to False for everything defined
# in this module.
wp.set_module_options({"enable_backward": False})


# TODO(team): remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml
from mujoco_warp._src.types import TEXTURE_DTYPE


@wp.func
def sample_texture(
  # Model:
  geom_type: wp.array(dtype=int),
  mesh_faceadr: wp.array(dtype=int),
  # In:
  geom_id: int,
  tex_repeat: wp.vec2,
  tex: TEXTURE_DTYPE,
  pos: wp.vec3,
  rot: wp.mat33,
  mesh_facetexcoord: wp.array(dtype=wp.vec3i),
  mesh_texcoord: wp.array(dtype=wp.vec2),
  mesh_texcoord_offsets: wp.array(dtype=int),
  hit_point: wp.vec3,
  bary_u: float,
  bary_v: float,
  f: int,
  mesh_id: int,
) -> wp.vec3:
  """Sample the RGB texture color for a ray hit on a geom.

  Plane geoms use the first two components of the hit point expressed in the
  geom frame (`rot^T @ (hit_point - pos)`) as texture coordinates. Mesh geoms
  blend the three per-face texture coordinates with weights `bary_u`, `bary_v`
  and `1 - bary_u - bary_v`. Every other geom type samples at coordinate (0, 0).

  The coordinates are scaled by `tex_repeat` and reduced to their fractional
  part before the texture lookup.

  Returns:
    The first three channels of the sampled texel, or black when the geom is a
    mesh and `f` or `mesh_id` is negative.
  """
  gtype = geom_type[geom_id]
  uv = wp.vec2(0.0, 0.0)

  if gtype == GeomType.PLANE:
    # hit point in the plane's local frame; x/y act as texture coordinates
    hit_local = wp.transpose(rot) @ (hit_point - pos)
    uv = wp.vec2(hit_local[0], hit_local[1])
  elif gtype == GeomType.MESH:
    if f < 0 or mesh_id < 0:
      return wp.vec3(0.0, 0.0, 0.0)

    # indices of the three texture coordinates used by the hit face
    tri = mesh_facetexcoord[mesh_faceadr[mesh_id] + f]
    tc_base = mesh_texcoord_offsets[mesh_id]
    uv0 = mesh_texcoord[tc_base + tri[0]]
    uv1 = mesh_texcoord[tc_base + tri[1]]
    uv2 = mesh_texcoord[tc_base + tri[2]]
    bary_w = 1.0 - bary_u - bary_v
    uv = uv0 * bary_u + uv1 * bary_v + uv2 * bary_w

  # apply the repeat factors, then keep only the fractional part
  su = uv[0] * tex_repeat[0]
  sv = uv[1] * tex_repeat[1]
  wrapped = wp.vec2(su - wp.floor(su), sv - wp.floor(sv))

  rgba = wp.texture_sample(tex, wrapped, dtype=wp.vec4)
  return wp.vec3(rgba[0], rgba[1], rgba[2])


# TODO: Investigate combining cast_ray and cast_ray_first_hit
@wp.func
def cast_ray(
  # Model:
  geom_type: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  bvh_id: wp.uint64,
  group_root: int,
  world_id: int,
  bvh_ngeom: int,
  enabled_geom_ids: wp.array(dtype=int),
  mesh_bvh_id: wp.array(dtype=wp.uint64),
  hfield_bvh_id: wp.array(dtype=wp.uint64),
  ray_origin_world: wp.vec3,
  ray_dir_world: wp.vec3,
) -> Tuple[int, float, wp.vec3, float, float, int, int]:
  """Find the closest geom hit by a ray among the candidates returned by the BVH.

  Each BVH candidate is mapped back to a geom id through `enabled_geom_ids`
  and intersected with the ray routine matching its geom type. The candidate
  with the smallest non-negative distance wins.

  Returns:
    A tuple `(geom_id, dist, normal, bary_u, bary_v, face_idx, mesh_id)`.
    Without any hit, `geom_id` is -1, `dist` is `MJ_MAXVAL`, `normal` is zero,
    the barycentric values are 0 and `face_idx` / `mesh_id` are -1. `mesh_id`
    is only filled in by the mesh branch; the height-field branch fills the
    barycentric values and face index but leaves `mesh_id` at -1.
  """
  dist = float(MJ_MAXVAL)
  normal = wp.vec3(0.0, 0.0, 0.0)
  geom_id = int(-1)
  bary_u = float(0.0)
  bary_v = float(0.0)
  face_idx = int(-1)
  geom_mesh_id = int(-1)

  # geom_size may hold fewer worlds than the data arrays; wrap the world index
  size_world = world_id % geom_size.shape[0]

  query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root)
  bounds_nr = int(0)

  while wp.bvh_query_next(query, bounds_nr, dist):
    gi_global = bounds_nr
    gi_bvh_local = gi_global - (world_id * bvh_ngeom)
    gi = enabled_geom_ids[gi_bvh_local]

    # Per-candidate results start out as "no hit". Initialising `d` here makes
    # sure a geom type with no matching branch below is treated as a miss
    # instead of being compared using a stale or undefined distance.
    d = float(-1.0)
    hit_mesh_id = int(-1)
    u = float(0.0)
    v = float(0.0)
    f = int(-1)
    n = wp.vec3(0.0, 0.0, 0.0)

    # TODO: Investigate branch elimination with static loop unrolling
    if geom_type[gi] == GeomType.PLANE:
      d, n = ray_plane(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.HFIELD:
      d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh(
        hfield_bvh_id,
        geom_dataid[gi],
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        ray_origin_world,
        ray_dir_world,
        dist,
      )
    if geom_type[gi] == GeomType.SPHERE:
      # ray_sphere receives the first size component squared
      d, n = ray_sphere(
        geom_xpos_in[world_id, gi],
        geom_size[size_world, gi][0] * geom_size[size_world, gi][0],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.ELLIPSOID:
      d, n = ray_ellipsoid(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.CAPSULE:
      d, n = ray_capsule(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.CYLINDER:
      d, n = ray_cylinder(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.BOX:
      # the middle return value of ray_box is not needed here
      d, box_all, n = ray_box(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.MESH:
      d, n, u, v, f, hit_mesh_id = ray_mesh_with_bvh(
        mesh_bvh_id,
        geom_dataid[gi],
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        ray_origin_world,
        ray_dir_world,
        dist,
      )

    # keep the candidate only if it is a valid hit closer than the best so far
    if d >= 0.0 and d < dist:
      dist = d
      normal = n
      geom_id = gi
      bary_u = u
      bary_v = v
      face_idx = f
      geom_mesh_id = hit_mesh_id

  return geom_id, dist, normal, bary_u, bary_v, face_idx, geom_mesh_id


@wp.func
def cast_ray_first_hit(
  # Model:
  geom_type: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  bvh_id: wp.uint64,
  group_root: int,
  world_id: int,
  bvh_ngeom: int,
  enabled_geom_ids: wp.array(dtype=int),
  mesh_bvh_id: wp.array(dtype=wp.uint64),
  hfield_bvh_id: wp.array(dtype=wp.uint64),
  ray_origin_world: wp.vec3,
  ray_dir_world: wp.vec3,
  max_dist: float,
) -> bool:
  """A simpler version of casting rays that only checks for the first hit.

  Walks the BVH candidates for the ray and returns as soon as any geom reports
  a hit distance in `[0, max_dist)`. Mesh geoms use the any-hit mesh query and
  are treated as a hit at distance 0 when it succeeds.

  Returns:
    True if some geom is hit within `max_dist`, False otherwise.
  """
  # geom_size may hold fewer worlds than the data arrays; wrap the world index
  size_world = world_id % geom_size.shape[0]

  query = wp.bvh_query_ray(bvh_id, ray_origin_world, ray_dir_world, group_root)
  bounds_nr = int(0)

  while wp.bvh_query_next(query, bounds_nr, max_dist):
    gi_global = bounds_nr
    gi_bvh_local = gi_global - (world_id * bvh_ngeom)
    gi = enabled_geom_ids[gi_bvh_local]

    # Start every candidate as a miss so that a geom type with no matching
    # branch below cannot trigger a hit through a stale or undefined distance.
    d = float(-1.0)

    # TODO: Investigate branch elimination with static loop unrolling
    if geom_type[gi] == GeomType.PLANE:
      d, n = ray_plane(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.HFIELD:
      d, n, u, v, f, geom_hfield_id = ray_mesh_with_bvh(
        hfield_bvh_id,
        geom_dataid[gi],
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        ray_origin_world,
        ray_dir_world,
        max_dist,
      )
    if geom_type[gi] == GeomType.SPHERE:
      # ray_sphere receives the first size component squared
      d, n = ray_sphere(
        geom_xpos_in[world_id, gi],
        geom_size[size_world, gi][0] * geom_size[size_world, gi][0],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.ELLIPSOID:
      d, n = ray_ellipsoid(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.CAPSULE:
      d, n = ray_capsule(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.CYLINDER:
      d, n = ray_cylinder(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.BOX:
      # the middle return value of ray_box is not needed here
      d, box_all, n = ray_box(
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        geom_size[size_world, gi],
        ray_origin_world,
        ray_dir_world,
      )
    if geom_type[gi] == GeomType.MESH:
      hit = ray_mesh_with_bvh_anyhit(
        mesh_bvh_id,
        geom_dataid[gi],
        geom_xpos_in[world_id, gi],
        geom_xmat_in[world_id, gi],
        ray_origin_world,
        ray_dir_world,
        max_dist,
      )
      # the any-hit query only reports a boolean; encode it as a distance
      d = 0.0 if hit else -1.0

    if d >= 0.0 and d < max_dist:
      return True

  return False


@wp.func
def compute_lighting(
  # Model:
  geom_type: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  use_shadows: bool,
  bvh_id: wp.uint64,
  group_root: int,
  bvh_ngeom: int,
  enabled_geom_ids: wp.array(dtype=int),
  world_id: int,
  mesh_bvh_id: wp.array(dtype=wp.uint64),
  hfield_bvh_id: wp.array(dtype=wp.uint64),
  lightactive: bool,
  lighttype: int,
  lightcastshadow: bool,
  lightpos: wp.vec3,
  lightdir: wp.vec3,
  normal: wp.vec3,
  hitpoint: wp.vec3,
) -> float:
  """Compute the scalar contribution of one light at a surface point.

  The result is `max(0, normal . L) * attenuation * visibility`, where `L` is
  the unit direction from the surface towards the light. Light type 1 is
  handled as directional (no distance attenuation); every other type is
  attenuated by `1 / (1 + 0.02 * dist^2)`, and type 0 additionally applies a
  spot-cone factor that ramps from 0 to 1 between cosines 0.85 and 0.95.

  When `use_shadows` and `lightcastshadow` are both set, a shadow ray is cast
  towards the light and the visibility drops to 0.3 if it hits anything.

  Returns:
    0.0 for an inactive light or a surface facing away from it, otherwise the
    contribution described above.
  """
  unlit = float(0.0)

  # TODO: We should probably only be looping over active lights
  # in the first place with a static loop of enabled light idx?
  if not lightactive:
    return unlit

  to_light = wp.vec3(0.0, 0.0, 0.0)
  dist_to_light = float(MJ_MAXVAL)
  falloff = float(1.0)

  if lighttype == 1:  # directional light
    to_light = wp.normalize(-lightdir)
  else:
    to_light, dist_to_light = math.normalize_with_norm(lightpos - hitpoint)
    falloff = 1.0 / (1.0 + 0.02 * dist_to_light * dist_to_light)
    if lighttype == 0:  # spot light
      cos_theta = wp.dot(-to_light, wp.normalize(lightdir))
      cone = wp.min(1.0, wp.max(0.0, (cos_theta - 0.85) / (0.95 - 0.85)))
      falloff = falloff * cone

  ndotl = wp.max(0.0, wp.dot(normal, to_light))
  if ndotl == 0.0:
    return unlit

  shade = float(1.0)

  if use_shadows and lightcastshadow:
    # Offset the shadow ray origin a little along the surface normal so the
    # ray does not immediately re-hit the surface it starts from
    shadow_origin = hitpoint + normal * 1.0e-4

    # Directional lights use a large fixed range; other lights stop just
    # short of the light position
    max_t = float(1.0e8)
    if lighttype != 1:
      max_t = dist_to_light - 1.0e-3

    occluded = cast_ray_first_hit(
      geom_type,
      geom_dataid,
      geom_size,
      geom_xpos_in,
      geom_xmat_in,
      bvh_id,
      group_root,
      world_id,
      bvh_ngeom,
      enabled_geom_ids,
      mesh_bvh_id,
      hfield_bvh_id,
      shadow_origin,
      to_light,
      max_t,
    )

    if occluded:
      shade = 0.3

  return ndotl * falloff * shade


@event_scope
def render(m: Model, d: Data, rc: RenderContext):
  """Render the current frame.

  Outputs are stored in buffers within the render context: `rc.rgb_data` is
  first filled with `rc.background_color` and `rc.depth_data` with 0.0, then a
  single kernel is launched over `(d.nworld, rc.total_rays)` that writes a
  packed RGBA value and/or a hit distance for every ray that hits something.

  Args:
    m: The model on device.
    d: The data on device.
    rc: The render context on device.
  """
  rc.rgb_data.fill_(rc.background_color)
  rc.depth_data.fill_(0.0)

  @wp.kernel(module="unique", enable_backward=False)
  def _render_megakernel(
    # Model:
    geom_type: wp.array(dtype=int),
    geom_dataid: wp.array(dtype=int),
    geom_matid: wp.array2d(dtype=int),
    geom_size: wp.array2d(dtype=wp.vec3),
    geom_rgba: wp.array2d(dtype=wp.vec4),
    cam_projection: wp.array(dtype=int),
    cam_fovy: wp.array2d(dtype=float),
    cam_sensorsize: wp.array(dtype=wp.vec2),
    cam_intrinsic: wp.array2d(dtype=wp.vec4),
    light_type: wp.array2d(dtype=int),
    light_castshadow: wp.array2d(dtype=bool),
    light_active: wp.array2d(dtype=bool),
    mesh_faceadr: wp.array(dtype=int),
    mat_texid: wp.array3d(dtype=int),
    mat_texrepeat: wp.array2d(dtype=wp.vec2),
    mat_rgba: wp.array2d(dtype=wp.vec4),
    # Data in:
    geom_xpos_in: wp.array2d(dtype=wp.vec3),
    geom_xmat_in: wp.array2d(dtype=wp.mat33),
    cam_xpos_in: wp.array2d(dtype=wp.vec3),
    cam_xmat_in: wp.array2d(dtype=wp.mat33),
    light_xpos_in: wp.array2d(dtype=wp.vec3),
    light_xdir_in: wp.array2d(dtype=wp.vec3),
    # In:
    nrender: int,
    use_shadows: bool,
    bvh_ngeom: int,
    cam_res: wp.array(dtype=wp.vec2i),
    cam_id_map: wp.array(dtype=int),
    ray: wp.array(dtype=wp.vec3),
    rgb_adr: wp.array(dtype=int),
    depth_adr: wp.array(dtype=int),
    render_rgb: wp.array(dtype=bool),
    render_depth: wp.array(dtype=bool),
    bvh_id: wp.uint64,
    group_root: wp.array(dtype=int),
    flex_bvh_id: wp.uint64,
    flex_group_root: wp.array(dtype=int),
    enabled_geom_ids: wp.array(dtype=int),
    mesh_bvh_id: wp.array(dtype=wp.uint64),
    mesh_facetexcoord: wp.array(dtype=wp.vec3i),
    mesh_texcoord: wp.array(dtype=wp.vec2),
    mesh_texcoord_offsets: wp.array(dtype=int),
    hfield_bvh_id: wp.array(dtype=wp.uint64),
    flex_rgba: wp.array(dtype=wp.vec4),
    # TODO: remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml
    textures: wp.array(dtype=TEXTURE_DTYPE),
    # Out:
    rgb_out: wp.array2d(dtype=wp.uint32),
    depth_out: wp.array2d(dtype=float),
  ):
    world_idx, ray_idx = wp.tid()

    # Map global ray_idx -> (cam_idx, ray_idx_local) using cumulative sizes
    cam_idx = int(-1)
    ray_idx_local = int(-1)
    accum = int(0)
    for i in range(nrender):
      num_i = cam_res[i][0] * cam_res[i][1]
      if ray_idx < accum + num_i:
        cam_idx = i
        ray_idx_local = ray_idx - accum
        break
      accum += num_i

    if cam_idx == -1 or ray_idx_local < 0:
      return

    if not render_rgb[cam_idx] and not render_depth[cam_idx]:
      return

    # Map active camera index to MuJoCo camera ID
    mujoco_cam_id = cam_id_map[cam_idx]

    if wp.static(rc.use_precomputed_rays):
      ray_dir_local_cam = ray[ray_idx]
    else:
      img_w = cam_res[cam_idx][0]
      img_h = cam_res[cam_idx][1]
      px = ray_idx_local % img_w
      py = ray_idx_local // img_w
      ray_dir_local_cam = compute_ray(
        cam_projection[mujoco_cam_id],
        cam_fovy[world_idx % cam_fovy.shape[0], mujoco_cam_id],
        cam_sensorsize[mujoco_cam_id],
        cam_intrinsic[world_idx % cam_intrinsic.shape[0], mujoco_cam_id],
        img_w,
        img_h,
        px,
        py,
        wp.static(rc.znear),
      )

    ray_dir_world = cam_xmat_in[world_idx, mujoco_cam_id] @ ray_dir_local_cam
    ray_origin_world = cam_xpos_in[world_idx, mujoco_cam_id]

    geom_id, dist, normal, u, v, f, mesh_id = cast_ray(
      geom_type,
      geom_dataid,
      geom_size,
      geom_xpos_in,
      geom_xmat_in,
      bvh_id,
      group_root[world_idx],
      world_idx,
      bvh_ngeom,
      enabled_geom_ids,
      mesh_bvh_id,
      hfield_bvh_id,
      ray_origin_world,
      ray_dir_world,
    )

    if wp.static(m.nflex > 0):
      # Unpack the flex query into its own variables. Reusing u / v / f here
      # would overwrite the barycentrics and face index of the geom hit found
      # above even when the flex is missed or lies farther away, which would
      # corrupt texture sampling for mesh geoms.
      flex_dist, flex_normal, flex_u, flex_v, flex_f = ray_flex_with_bvh(
        flex_bvh_id,
        flex_group_root[world_idx],
        ray_origin_world,
        ray_dir_world,
        dist,
      )
      if flex_dist >= 0.0 and flex_dist < dist:
        dist = flex_dist
        normal = flex_normal
        # -2 marks a flex hit; u / v / f / mesh_id are not used for flex hits
        # because texture sampling below is skipped for geom_id == -2
        geom_id = -2

    # Early Out
    if geom_id == -1:
      return

    if render_depth[cam_idx]:
      depth_out[world_idx, depth_adr[cam_idx] + ray_idx_local] = dist

    if not render_rgb[cam_idx]:
      return

    # Shade the pixel
    hit_point = ray_origin_world + ray_dir_world * dist

    if geom_id == -2:
      # TODO: Currently flex textures are not supported, and only the first rgba value
      # is used until further flex support is added.
      color = flex_rgba[0]
    elif geom_matid[world_idx % geom_matid.shape[0], geom_id] == -1:
      color = geom_rgba[world_idx % geom_rgba.shape[0], geom_id]
    else:
      color = mat_rgba[world_idx % mat_rgba.shape[0], geom_matid[world_idx % geom_matid.shape[0], geom_id]]

    base_color = wp.vec3(color[0], color[1], color[2])

    if wp.static(rc.use_textures):
      if geom_id != -2:
        mat_id = geom_matid[world_idx % geom_matid.shape[0], geom_id]
        if mat_id >= 0:
          tex_id = mat_texid[world_idx % mat_texid.shape[0], mat_id, 1]
          if tex_id >= 0:
            tex_color = sample_texture(
              geom_type,
              mesh_faceadr,
              geom_id,
              mat_texrepeat[world_idx % mat_texrepeat.shape[0], mat_id],
              textures[tex_id],
              geom_xpos_in[world_idx, geom_id],
              geom_xmat_in[world_idx, geom_id],
              mesh_facetexcoord,
              mesh_texcoord,
              mesh_texcoord_offsets,
              hit_point,
              u,
              v,
              f,
              mesh_id,
            )
            base_color = wp.cw_mul(base_color, tex_color)

    # Hemispheric ambient term: blend two ambient colors by the z component of
    # the unit normal (falling back to +z when the normal has zero length)
    len_n = wp.length(normal)
    n = normal if len_n > 0.0 else wp.vec3(0.0, 0.0, 1.0)
    n = wp.normalize(n)
    hemispheric = 0.5 * (n[2] + 1.0)
    ambient_color = wp.vec3(0.4, 0.4, 0.45) * hemispheric + wp.vec3(0.1, 0.1, 0.12) * (1.0 - hemispheric)
    result = 0.5 * wp.cw_mul(base_color, ambient_color)

    # Apply lighting and shadows
    for light_id in range(wp.static(m.nlight)):
      light_contribution = compute_lighting(
        geom_type,
        geom_dataid,
        geom_size,
        geom_xpos_in,
        geom_xmat_in,
        use_shadows,
        bvh_id,
        group_root[world_idx],
        bvh_ngeom,
        enabled_geom_ids,
        world_idx,
        mesh_bvh_id,
        hfield_bvh_id,
        light_active[world_idx % light_active.shape[0], light_id],
        light_type[world_idx % light_type.shape[0], light_id],
        light_castshadow[world_idx % light_castshadow.shape[0], light_id],
        light_xpos_in[world_idx, light_id],
        light_xdir_in[world_idx, light_id],
        normal,
        hit_point,
      )
      result = result + base_color * light_contribution

    # Clamp each channel to [0, 1] before packing to 8 bits per channel
    hit_color = wp.min(result, wp.vec3(1.0, 1.0, 1.0))
    hit_color = wp.max(hit_color, wp.vec3(0.0, 0.0, 0.0))

    rgb_out[world_idx, rgb_adr[cam_idx] + ray_idx_local] = pack_rgba_to_uint32(
      hit_color[0] * 255.0,
      hit_color[1] * 255.0,
      hit_color[2] * 255.0,
      255.0,
    )

  wp.launch(
    kernel=_render_megakernel,
    dim=(d.nworld, rc.total_rays),
    inputs=[
      m.geom_type,
      m.geom_dataid,
      m.geom_matid,
      m.geom_size,
      m.geom_rgba,
      m.cam_projection,
      m.cam_fovy,
      m.cam_sensorsize,
      m.cam_intrinsic,
      m.light_type,
      m.light_castshadow,
      m.light_active,
      m.mesh_faceadr,
      m.mat_texid,
      m.mat_texrepeat,
      m.mat_rgba,
      d.geom_xpos,
      d.geom_xmat,
      d.cam_xpos,
      d.cam_xmat,
      d.light_xpos,
      d.light_xdir,
      rc.nrender,
      rc.use_shadows,
      rc.bvh_ngeom,
      rc.cam_res,
      rc.cam_id_map,
      rc.ray,
      rc.rgb_adr,
      rc.depth_adr,
      rc.render_rgb,
      rc.render_depth,
      rc.bvh_id,
      rc.group_root,
      rc.flex_bvh_id,
      rc.flex_group_root,
      rc.enabled_geom_ids,
      rc.mesh_bvh_id,
      rc.mesh_facetexcoord,
      rc.mesh_texcoord,
      rc.mesh_texcoord_offsets,
      rc.hfield_bvh_id,
      rc.flex_rgba,
      rc.textures,
    ],
    outputs=[
      rc.rgb_data,
      rc.depth_data,
    ],
  )