Source code for mujoco_warp._src.bvh

# Copyright 2026 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Tuple

import mujoco
import numpy as np
import warp as wp

from mujoco_warp._src.types import MJ_MAXVAL
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import RenderContext
from mujoco_warp._src.warp_util import event_scope

wp.set_module_options({"enable_backward": False})


@event_scope
def refit_bvh(m: Model, d: Data, rc: RenderContext):
  """Refit the dynamic BVH structures in the render context.

  The scene (geom) BVH is always refit; the flex BVH is refit only when the
  model contains at least one flex.

  Args:
    m: The model; `m.nflex` decides whether the flex BVH is refit.
    d: The data, forwarded to the scene and flex refit routines.
    rc: The render context, forwarded to the scene and flex refit routines.
  """
  refit_scene_bvh(m, d, rc)
  if m.nflex:
    refit_flex_bvh(m, d, rc)
@wp.func
def _compute_box_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the world-space AABB (lower, upper) of an oriented box.

  `size` holds the box half-extents along its local axes; the eight local
  corners are rotated by `rot`, offset by `pos`, and min/max-reduced.
  """
  min_bound = wp.vec3(MJ_MAXVAL, MJ_MAXVAL, MJ_MAXVAL)
  max_bound = wp.vec3(-MJ_MAXVAL, -MJ_MAXVAL, -MJ_MAXVAL)

  for i in range(2):
    for j in range(2):
      for k in range(2):
        # (2 * index - 1) maps {0, 1} to {-1, +1}
        local_corner = wp.vec3(
          size[0] * (2.0 * float(i) - 1.0),
          size[1] * (2.0 * float(j) - 1.0),
          size[2] * (2.0 * float(k) - 1.0),
        )
        world_corner = pos + rot @ local_corner
        min_bound = wp.min(min_bound, world_corner)
        max_bound = wp.max(max_bound, world_corner)

  return min_bound, max_bound


@wp.func
def _compute_sphere_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the AABB of a sphere of radius `size[0]` centered at `pos`.

  `rot` is unused; it is accepted so all bound functions share one signature.
  """
  radius = size[0]
  return pos - wp.vec3(radius, radius, radius), pos + wp.vec3(radius, radius, radius)


@wp.func
def _compute_capsule_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the AABB of a capsule (radius `size[0]`, half-length `size[1]`).

  The capsule axis is the third column of `rot`; the AABB of the axis segment
  is inflated by the radius in every direction.
  """
  radius = size[0]
  half_length = size[1]
  z = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2])
  world_end1 = pos - z * half_length
  world_end2 = pos + z * half_length

  seg_min = wp.min(world_end1, world_end2)
  seg_max = wp.max(world_end1, world_end2)

  inflate = wp.vec3(radius, radius, radius)
  return seg_min - inflate, seg_max + inflate


@wp.func
def _compute_plane_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the AABB of a plane, modelled as a square in its local XY plane.

  The square half-width is twice the larger of `size[0]` and `size[1]`, or
  1000 when either is non-positive. The result is padded by 0.01 per axis so
  the box is never degenerate (zero thickness).
  """
  # If plane size is non-positive, treat as infinite plane and use a large default extent
  size_scale = wp.max(size[0], size[1]) * 2.0
  if size[0] <= 0.0 or size[1] <= 0.0:
    size_scale = 1000.0

  min_bound = wp.vec3(MJ_MAXVAL, MJ_MAXVAL, MJ_MAXVAL)
  max_bound = wp.vec3(-MJ_MAXVAL, -MJ_MAXVAL, -MJ_MAXVAL)

  for i in range(2):
    for j in range(2):
      local_corner = wp.vec3(
        size_scale * (2.0 * float(i) - 1.0),
        size_scale * (2.0 * float(j) - 1.0),
        0.0,
      )
      world_corner = pos + rot @ local_corner
      min_bound = wp.min(min_bound, world_corner)
      max_bound = wp.max(max_bound, world_corner)

  min_bound = min_bound - wp.vec3(0.01, 0.01, 0.01)
  max_bound = max_bound + wp.vec3(0.01, 0.01, 0.01)

  return min_bound, max_bound


@wp.func
def _compute_ellipsoid_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the AABB of an ellipsoid with semi-axes `size` oriented by `rot`."""
  # Half-extent along each world axis equals the norm of the corresponding row of rot*diag(size)
  row0 = wp.vec3(rot[0, 0] * size[0], rot[0, 1] * size[1], rot[0, 2] * size[2])
  row1 = wp.vec3(rot[1, 0] * size[0], rot[1, 1] * size[1], rot[1, 2] * size[2])
  row2 = wp.vec3(rot[2, 0] * size[0], rot[2, 1] * size[1], rot[2, 2] * size[2])
  extent = wp.vec3(wp.length(row0), wp.length(row1), wp.length(row2))
  return pos - extent, pos + extent


@wp.func
def _compute_cylinder_bounds(
  # In:
  pos: wp.vec3,
  rot: wp.mat33,
  size: wp.vec3,
) -> Tuple[wp.vec3, wp.vec3]:
  """Return the AABB of a cylinder (radius `size[0]`, half-height `size[1]`).

  Per world axis, the extent is the projection of the end-cap disc (spanned by
  the first two columns of `rot`) plus the projection of the half-height along
  the cylinder axis (third column of `rot`).
  """
  radius = size[0]
  half_height = size[1]

  axis = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2])
  axis_abs = wp.vec3(wp.abs(axis[0]), wp.abs(axis[1]), wp.abs(axis[2]))

  basis_x = wp.vec3(rot[0, 0], rot[1, 0], rot[2, 0])
  basis_y = wp.vec3(rot[0, 1], rot[1, 1], rot[2, 1])

  radial_x = radius * wp.sqrt(basis_x[0] * basis_x[0] + basis_y[0] * basis_y[0])
  radial_y = radius * wp.sqrt(basis_x[1] * basis_x[1] + basis_y[1] * basis_y[1])
  radial_z = radius * wp.sqrt(basis_x[2] * basis_x[2] + basis_y[2] * basis_y[2])

  extent = wp.vec3(
    radial_x + half_height * axis_abs[0],
    radial_y + half_height * axis_abs[1],
    radial_z + half_height * axis_abs[2],
  )

  return pos - extent, pos + extent


@wp.kernel
def _compute_bvh_bounds(
  # Model:
  geom_type: wp.array(dtype=int),
  geom_dataid: wp.array(dtype=int),
  geom_size: wp.array2d(dtype=wp.vec3),
  # Data in:
  geom_xpos_in: wp.array2d(dtype=wp.vec3),
  geom_xmat_in: wp.array2d(dtype=wp.mat33),
  # In:
  bvh_ngeom: int,
  enabled_geom_ids: wp.array(dtype=int),
  mesh_bounds_size: wp.array(dtype=wp.vec3),
  hfield_bounds_size: wp.array(dtype=wp.vec3),
  # Out:
  lower_out: wp.array(dtype=wp.vec3),
  upper_out: wp.array(dtype=wp.vec3),
  group_out: wp.array(dtype=int),
):
  """Write the world-space AABB and BVH group of one enabled geom in one world.

  Launched over (world, enabled geom). Results are stored at the flat index
  `world_id * bvh_ngeom + geom_local_id`, and the group is the world id.
  Meshes and heightfields are bounded by an oriented box whose half-extents
  come from `mesh_bounds_size` / `hfield_bounds_size`, centered on the geom
  position.
  """
  world_id, geom_local_id = wp.tid()
  geom_id = enabled_geom_ids[geom_local_id]

  pos = geom_xpos_in[world_id, geom_id]
  rot = geom_xmat_in[world_id, geom_id]
  # geom_size may hold fewer rows than there are worlds; wrap the world index
  size = geom_size[world_id % geom_size.shape[0], geom_id]
  geomtype = geom_type[geom_id]

  # TODO: Investigate branch elimination with static loop unrolling
  if geomtype == GeomType.SPHERE:
    lower_bound, upper_bound = _compute_sphere_bounds(pos, rot, size)
  elif geomtype == GeomType.CAPSULE:
    lower_bound, upper_bound = _compute_capsule_bounds(pos, rot, size)
  elif geomtype == GeomType.PLANE:
    lower_bound, upper_bound = _compute_plane_bounds(pos, rot, size)
  elif geomtype == GeomType.MESH:
    size = mesh_bounds_size[geom_dataid[geom_id]]
    lower_bound, upper_bound = _compute_box_bounds(pos, rot, size)
  elif geomtype == GeomType.ELLIPSOID:
    lower_bound, upper_bound = _compute_ellipsoid_bounds(pos, rot, size)
  elif geomtype == GeomType.CYLINDER:
    lower_bound, upper_bound = _compute_cylinder_bounds(pos, rot, size)
  elif geomtype == GeomType.BOX:
    lower_bound, upper_bound = _compute_box_bounds(pos, rot, size)
  elif geomtype == GeomType.HFIELD:
    size = hfield_bounds_size[geom_dataid[geom_id]]
    lower_bound, upper_bound = _compute_box_bounds(pos, rot, size)

  out_id = world_id * bvh_ngeom + geom_local_id
  lower_out[out_id] = lower_bound
  upper_out[out_id] = upper_bound
  group_out[out_id] = world_id


@wp.kernel
def compute_bvh_group_roots(
  # In:
  bvh_id: wp.uint64,
  # Out:
  group_root_out: wp.array(dtype=int),
):
  """Store, for each group (one thread per group), the BVH root of that group."""
  tid = wp.tid()
  root = wp.bvh_get_group_root(bvh_id, tid)
  group_root_out[tid] = root


def build_scene_bvh(mjm: mujoco.MjModel, mjd: mujoco.MjData, rc: RenderContext, nworld: int):
  """Build a global BVH for all geometries in all worlds.

  The host model/data are tiled `nworld` times, bounds are computed into
  `rc.lower` / `rc.upper` / `rc.group`, and a grouped BVH is built from them.
  The BVH, its id and the per-world group roots are stored on `rc`.

  Args:
    mjm: Host MuJoCo model providing geom types, data ids and sizes.
    mjd: Host MuJoCo data providing geom positions and orientations.
    rc: Render context that supplies the bound buffers and receives the BVH.
    nworld: Number of worlds to replicate the scene into.
  """
  geom_type = wp.array(mjm.geom_type, dtype=int)
  geom_dataid = wp.array(mjm.geom_dataid, dtype=int)
  geom_size = wp.array(np.tile(mjm.geom_size[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3)
  geom_xpos = wp.array(np.tile(mjd.geom_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3)
  geom_xmat = wp.array(
    np.tile(mjd.geom_xmat.reshape(mjm.ngeom, 3, 3)[np.newaxis, :, :, :], (nworld, 1, 1, 1)),
    dtype=wp.mat33,
  )

  wp.launch(
    kernel=_compute_bvh_bounds,
    dim=(nworld, rc.bvh_ngeom),
    inputs=[
      geom_type,
      geom_dataid,
      geom_size,
      geom_xpos,
      geom_xmat,
      rc.bvh_ngeom,
      rc.enabled_geom_ids,
      rc.mesh_bounds_size,
      rc.hfield_bounds_size,
      rc.lower,
      rc.upper,
      rc.group,
    ],
  )

  bvh = wp.Bvh(rc.lower, rc.upper, groups=rc.group, constructor="sah")

  # BVH handle must be stored to avoid garbage collection
  rc.bvh = bvh
  rc.bvh_id = bvh.id

  wp.launch(
    kernel=compute_bvh_group_roots,
    dim=nworld,
    inputs=[bvh.id],
    outputs=[rc.group_root],
  )


def refit_scene_bvh(m: Model, d: Data, rc: RenderContext):
  """Recompute the geom bounds from the current data and refit the scene BVH.

  Args:
    m: The model (geom types, data ids and sizes).
    d: The data (current geom positions and orientations, number of worlds).
    rc: Render context holding the bound buffers and the BVH to refit.
  """
  wp.launch(
    kernel=_compute_bvh_bounds,
    dim=(d.nworld, rc.bvh_ngeom),
    inputs=[
      m.geom_type,
      m.geom_dataid,
      m.geom_size,
      d.geom_xpos,
      d.geom_xmat,
      rc.bvh_ngeom,
      rc.enabled_geom_ids,
      rc.mesh_bounds_size,
      rc.hfield_bounds_size,
      rc.lower,
      rc.upper,
      rc.group,
    ],
  )

  rc.bvh.refit()


def build_mesh_bvh(
  mjm: mujoco.MjModel,
  meshid: int,
  constructor: str = "sah",
  leaf_size: int = 2,
) -> tuple[wp.Mesh, wp.vec3]:
  """Create a Warp mesh BVH from mesh data.

  Args:
    mjm: Host MuJoCo model holding the mesh vertex and face arrays.
    meshid: Index of the mesh to convert.
    constructor: BVH construction method passed to `wp.Mesh`.
    leaf_size: BVH leaf size passed to `wp.Mesh`.

  Returns:
    The Warp mesh and the per-axis half-extents (a length-3 NumPy array) of an
    origin-centered box that contains every mesh vertex.
  """
  v_start = mjm.mesh_vertadr[meshid]
  v_end = v_start + mjm.mesh_vertnum[meshid]
  points = mjm.mesh_vert[v_start:v_end]

  # faces of the last mesh run to the end of the face array
  f_start = mjm.mesh_faceadr[meshid]
  f_end = mjm.mesh_face.shape[0] if (meshid + 1) >= mjm.mesh_faceadr.shape[0] else mjm.mesh_faceadr[meshid + 1]
  indices = mjm.mesh_face[f_start:f_end]
  indices = indices.flatten()

  pmin = np.min(points, axis=0)
  pmax = np.max(points, axis=0)
  # The bounds kernel treats this value as the half-size of a box centered on
  # the geom position, so it must reach the farthest vertex on each side of the
  # local origin. 0.5 * (pmax - pmin) is only sufficient when the vertices are
  # symmetric about the origin.
  half = np.maximum(np.abs(pmin), np.abs(pmax))

  points = wp.array(points, dtype=wp.vec3)
  indices = wp.array(indices, dtype=wp.int32)
  mesh = wp.Mesh(points=points, indices=indices, bvh_constructor=constructor, bvh_leaf_size=leaf_size)

  return mesh, half


def _optimize_hfield_mesh(
  data: np.ndarray,
  nr: int,
  nc: int,
  sx: float,
  sy: float,
  sz_scale: float,
  width: float,
  height: float,
) -> tuple[np.ndarray, np.ndarray]:
  """Greedy meshing for heightfield optimization.

  Merges coplanar adjacent cells into larger rectangles to reduce triangle and
  vertex count.

  Args:
    data: Heightfield samples indexed as `data[row, col]`.
    nr: Number of sample rows.
    nc: Number of sample columns.
    sx: Scale applied to the normalized column coordinate (x).
    sy: Scale applied to the normalized row coordinate (y).
    sz_scale: Scale applied to the sample values (z).
    width: Divisor normalizing the column index before the -1 shift.
    height: Divisor normalizing the row index before the -1 shift.

  Returns:
    A float32 array of vertex positions (one xyz triple per vertex) and a flat
    int32 array of triangle vertex indices (three per triangle).
  """
  tol = 1e-5  # absolute tolerance for all planarity / coplanarity checks

  points_map = {}
  points_list = []
  indices_list = []

  def get_point_index(r, c):
    # Vertices are created lazily and shared between the quads that touch them
    if (r, c) in points_map:
      return points_map[(r, c)]

    # Compute vertex position
    x = sx * (float(c) / width - 1.0)
    y = sy * (float(r) / height - 1.0)
    z = float(data[r, c]) * sz_scale

    idx = len(points_list)
    points_list.append([x, y, z])
    points_map[(r, c)] = idx
    return idx

  visited = np.zeros((nr - 1, nc - 1), dtype=bool)

  for r in range(nr - 1):
    for c in range(nc - 1):
      if visited[r, c]:
        continue

      # Check if current cell is planar
      z00 = data[r, c]
      z01 = data[r, c + 1]
      z10 = data[r + 1, c]
      z11 = data[r + 1, c + 1]

      # Approx check for planarity: z00 + z11 == z01 + z10
      is_planar = abs((z00 + z11) - (z01 + z10)) < tol

      if not is_planar:
        # Must emit single cell (2 triangles)
        idx00 = get_point_index(r, c)
        idx01 = get_point_index(r, c + 1)
        idx10 = get_point_index(r + 1, c)
        idx11 = get_point_index(r + 1, c + 1)

        # Tri 1: TL, TR, BR
        indices_list.extend([idx00, idx01, idx11])
        # Tri 2: TL, BR, BL
        indices_list.extend([idx00, idx11, idx10])

        visited[r, c] = True
        continue

      # If planar, try to expand
      slope_x = z01 - z00
      slope_y = z10 - z00

      w = 1
      h = 1

      def fits_plane(rr, cc):
        if rr >= nr - 1 or cc >= nc - 1:
          return False

        # Check planarity of the cell itself
        cz00 = data[rr, cc]
        cz01 = data[rr, cc + 1]
        cz10 = data[rr + 1, cc]
        cz11 = data[rr + 1, cc + 1]

        if abs((cz00 + cz11) - (cz01 + cz10)) >= tol:
          return False

        # Check if it lies on the SAME plane as start cell
        # Expected z at (rr, cc)
        z_pred = z00 + (rr - r) * slope_y + (cc - c) * slope_x
        if abs(cz00 - z_pred) >= tol:
          return False

        # Since cell is planar and one corner matches, slopes must match if connected
        cslope_x = cz01 - cz00
        cslope_y = cz10 - cz00

        if abs(cslope_x - slope_x) >= tol or abs(cslope_y - slope_y) >= tol:
          return False

        return True

      # Expand width
      while c + w < nc - 1 and not visited[r, c + w] and fits_plane(r, c + w):
        w += 1

      # Expand height: a row is added only if every cell under the current width fits
      while r + h < nr - 1:
        if not all(not visited[r + h, c + k] and fits_plane(r + h, c + k) for k in range(w)):
          break
        h += 1

      # Mark visited
      visited[r : r + h, c : c + w] = True

      # Emit large quad
      idx_tl = get_point_index(r, c)
      idx_tr = get_point_index(r, c + w)
      idx_bl = get_point_index(r + h, c)
      idx_br = get_point_index(r + h, c + w)

      # Tri 1: TL, TR, BR
      indices_list.extend([idx_tl, idx_tr, idx_br])
      # Tri 2: TL, BR, BL
      indices_list.extend([idx_tl, idx_br, idx_bl])

  return np.array(points_list, dtype=np.float32), np.array(indices_list, dtype=np.int32)


def build_hfield_bvh(
  mjm: mujoco.MjModel,
  hfieldid: int,
  constructor: str = "sah",
  leaf_size: int = 2,
) -> tuple[wp.Mesh, wp.vec3]:
  """Create a Warp mesh BVH from heightfield data.

  Args:
    mjm: Host MuJoCo model holding the heightfield arrays.
    hfieldid: Index of the heightfield to convert.
    constructor: BVH construction method passed to `wp.Mesh`.
    leaf_size: BVH leaf size passed to `wp.Mesh`.

  Returns:
    The Warp mesh and the per-axis half-extents (a length-3 NumPy array) of an
    origin-centered box that contains every generated vertex.
  """
  nr = mjm.hfield_nrow[hfieldid]
  nc = mjm.hfield_ncol[hfieldid]
  sz = np.asarray(mjm.hfield_size[hfieldid], dtype=np.float32)

  adr = mjm.hfield_adr[hfieldid]
  data = mjm.hfield_data[adr : adr + nr * nc].reshape((nr, nc))

  # half the number of cells per axis (at least 0.5), so index / value - 1 spans [-1, 1]
  width = 0.5 * max(nc - 1, 1)
  height = 0.5 * max(nr - 1, 1)

  points, indices = _optimize_hfield_mesh(
    data,
    nr,
    nc,
    sz[0],
    sz[1],
    sz[2],
    width,
    height,
  )

  pmin = np.min(points, axis=0)
  pmax = np.max(points, axis=0)
  # The bounds kernel treats this value as the half-size of a box centered on
  # the geom position. Heights are data * sz[2], so the surface is generally not
  # symmetric about z = 0 and 0.5 * (pmax - pmin) would not enclose it.
  half = np.maximum(np.abs(pmin), np.abs(pmax))

  points = wp.array(points, dtype=wp.vec3)
  indices = wp.array(indices, dtype=wp.int32)

  mesh = wp.Mesh(
    points=points,
    indices=indices,
    bvh_constructor=constructor,
    bvh_leaf_size=leaf_size,
  )

  return mesh, half


@wp.kernel
def accumulate_flex_vertex_normals(
  # Model:
  flex_elem: wp.array(dtype=int),
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # Out:
  flexvert_norm_out: wp.array2d(dtype=wp.vec3),
):
  """Accumulate per-vertex normals by summing adjacent face normals.

  Launched over (world, element); each element is read as three consecutive
  vertex ids from `flex_elem` and its unit face normal is added to those
  three vertices.
  """
  worldid, elemid = wp.tid()

  # NOTE(review): vertex ids are used without a per-flex vertex offset and with
  # a fixed stride of 3, whereas the _build_flex_* kernels add `vert_adr`.
  # Confirm this is intended for models with several flexes or with 3D flexes.
  elem_base = elemid * 3
  i0 = flex_elem[elem_base + 0]
  i1 = flex_elem[elem_base + 1]
  i2 = flex_elem[elem_base + 2]

  v0 = flexvert_xpos_in[worldid, i0]
  v1 = flexvert_xpos_in[worldid, i1]
  v2 = flexvert_xpos_in[worldid, i2]

  face_nrm = wp.cross(v1 - v0, v2 - v0)
  face_nrm = wp.normalize(face_nrm)

  flexvert_norm_out[worldid, i0] += face_nrm
  flexvert_norm_out[worldid, i1] += face_nrm
  flexvert_norm_out[worldid, i2] += face_nrm


@wp.kernel
def normalize_vertex_normals(
  # Out:
  flexvert_norm_out: wp.array2d(dtype=wp.vec3),
):
  """Normalize accumulated vertex normals."""
  worldid, vertid = wp.tid()
  flexvert_norm_out[worldid, vertid] = wp.normalize(flexvert_norm_out[worldid, vertid])


@wp.kernel
def _build_flex_2d_elements(
  # Model:
  flex_elem: wp.array(dtype=int),
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flexvert_norm_in: wp.array2d(dtype=wp.vec3),
  elem_adr: int,
  vert_adr: int,
  face_offset: int,
  radius: float,
  nfaces: int,
  # Out:
  face_point_out: wp.array(dtype=wp.vec3),
  face_index_out: wp.array(dtype=int),
  group_out: wp.array(dtype=int),
):
  """Create faces from 2D flex elements (triangles).

  Two faces (top/bottom) per element, separated by the radius of the flex
  element. Each face gets its own three points; the group of both faces is
  the world id.
  """
  worldid, elemid = wp.tid()

  base = elem_adr + elemid * 3
  i0 = vert_adr + flex_elem[base + 0]
  i1 = vert_adr + flex_elem[base + 1]
  i2 = vert_adr + flex_elem[base + 2]

  v0 = flexvert_xpos_in[worldid, i0]
  v1 = flexvert_xpos_in[worldid, i1]
  v2 = flexvert_xpos_in[worldid, i2]

  n0 = flexvert_norm_in[worldid, i0]
  n1 = flexvert_norm_in[worldid, i1]
  n2 = flexvert_norm_in[worldid, i2]

  # offset every vertex along +/- its normal by the flex radius
  p0_pos = v0 + radius * n0
  p1_pos = v1 + radius * n1
  p2_pos = v2 + radius * n2

  p0_neg = v0 - radius * n0
  p1_neg = v1 - radius * n1
  p2_neg = v2 - radius * n2

  world_face_offset = worldid * nfaces

  # First face (top): i0, i1, i2
  face_id0 = world_face_offset + face_offset + 2 * elemid
  base0 = face_id0 * 3

  face_point_out[base0 + 0] = p0_pos
  face_point_out[base0 + 1] = p1_pos
  face_point_out[base0 + 2] = p2_pos

  face_index_out[base0 + 0] = base0 + 0
  face_index_out[base0 + 1] = base0 + 1
  face_index_out[base0 + 2] = base0 + 2

  group_out[face_id0] = worldid

  # Second face (bottom): i0, i2, i1 (opposite winding)
  face_id1 = world_face_offset + face_offset + 2 * elemid + 1
  base1 = face_id1 * 3

  face_point_out[base1 + 0] = p0_neg
  face_point_out[base1 + 1] = p1_neg
  face_point_out[base1 + 2] = p2_neg

  face_index_out[base1 + 0] = base1 + 0
  face_index_out[base1 + 1] = base1 + 2
  face_index_out[base1 + 2] = base1 + 1

  group_out[face_id1] = worldid


@wp.kernel
def _build_flex_2d_sides(
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flexvert_norm_in: wp.array2d(dtype=wp.vec3),
  flex_shell_in: wp.array(dtype=int),
  shell_adr: int,
  vert_adr: int,
  face_offset: int,
  radius: float,
  nface: int,
  # Out:
  face_point_out: wp.array(dtype=wp.vec3),
  face_index_out: wp.array(dtype=int),
  group_out: wp.array(dtype=int),
):
  """Create side faces from 2D flex shell fragments.

  For each shell fragment (edge i0 -> i1), we emit two triangles:
    - one using +radius
    - one using -radius (i0/i1 swapped)
  """
  worldid, shellid = wp.tid()

  base = shell_adr + 2 * shellid
  i0 = vert_adr + flex_shell_in[base + 0]
  i1 = vert_adr + flex_shell_in[base + 1]

  v0 = flexvert_xpos_in[worldid, i0]
  v1 = flexvert_xpos_in[worldid, i1]

  n0 = flexvert_norm_in[worldid, i0]
  n1 = flexvert_norm_in[worldid, i1]

  # First side i0, i1 with +radius
  face_id0 = worldid * nface + face_offset + 2 * shellid
  base0 = face_id0 * 3

  face_point_out[base0 + 0] = v0 + radius * n0
  face_point_out[base0 + 1] = v1 - radius * n1
  face_point_out[base0 + 2] = v1 + radius * n1

  face_index_out[base0 + 0] = base0 + 0
  face_index_out[base0 + 1] = base0 + 1
  face_index_out[base0 + 2] = base0 + 2

  # Second side i1, i0 with -radius: the same construction as the first side
  # with the endpoints swapped and the radius negated. This is the vertex order
  # _update_flex_face_points writes, so a refit does not flip the winding.
  face_id1 = worldid * nface + face_offset + 2 * shellid + 1
  base1 = face_id1 * 3

  face_point_out[base1 + 0] = v1 - radius * n1
  face_point_out[base1 + 1] = v0 + radius * n0
  face_point_out[base1 + 2] = v0 - radius * n0

  face_index_out[base1 + 0] = base1 + 0
  face_index_out[base1 + 1] = base1 + 1
  face_index_out[base1 + 2] = base1 + 2

  group_out[face_id0] = worldid
  group_out[face_id1] = worldid


@wp.kernel
def _build_flex_3d_shells(
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flex_shell_in: wp.array(dtype=int),
  shell_adr: int,
  vert_adr: int,
  face_offset: int,
  nface: int,
  # Out:
  face_point_out: wp.array(dtype=wp.vec3),
  face_index_out: wp.array(dtype=int),
  group_out: wp.array(dtype=int),
):
  """Create faces from 3D flex shell fragments (triangles).

  Each shell fragment contributes a single triangle whose vertices are taken
  directly from the flex vertex positions (one-sided surface).
  """
  worldid, shellid = wp.tid()

  sbase = shell_adr + shellid * 3
  i0 = vert_adr + flex_shell_in[sbase + 0]
  i1 = vert_adr + flex_shell_in[sbase + 1]
  i2 = vert_adr + flex_shell_in[sbase + 2]

  face_id = worldid * nface + face_offset + shellid
  fbase = face_id * 3

  v0 = flexvert_xpos_in[worldid, i0]
  v1 = flexvert_xpos_in[worldid, i1]
  v2 = flexvert_xpos_in[worldid, i2]

  face_point_out[fbase + 0] = v0
  face_point_out[fbase + 1] = v1
  face_point_out[fbase + 2] = v2

  face_index_out[fbase + 0] = fbase + 0
  face_index_out[fbase + 1] = fbase + 1
  face_index_out[fbase + 2] = fbase + 2

  group_out[face_id] = worldid


@wp.kernel
def _update_flex_face_points(
  # Model:
  nflex: int,
  flex_dim: wp.array(dtype=int),
  flex_vertadr: wp.array(dtype=int),
  flex_elemnum: wp.array(dtype=int),
  flex_elem: wp.array(dtype=int),
  # Data in:
  flexvert_xpos_in: wp.array2d(dtype=wp.vec3),
  # In:
  flex_shell_in: wp.array(dtype=int),
  flexvert_norm_in: wp.array2d(dtype=wp.vec3),
  flex_elemdataadr: wp.array(dtype=int),
  flex_shelldataadr: wp.array(dtype=int),
  flex_faceadr: wp.array(dtype=int),
  flex_radius: wp.array(dtype=float),
  flex_workadr: wp.array(dtype=int),
  flex_worknum: wp.array(dtype=int),
  nfaces: int,
  smooth: bool,
  # Out:
  face_point_out: wp.array(dtype=wp.vec3),
):
  """Rewrite the flex face points from the current vertex positions.

  Launched over (world, work item). A work item of a 2D flex is either an
  element (two offset faces) or, past the element count, a shell fragment (two
  side faces); a work item of any other flex is a shell triangle. Only
  `face_point_out` is written; indices and groups are left untouched.
  When `smooth` is false, 2D element faces use the flat face normal instead of
  the per-vertex normals.
  """
  worldid, workid = wp.tid()

  # identify which flex this work item belongs to
  f = int(0)
  locid = int(0)
  for i in range(nflex):
    locid = workid - flex_workadr[i]
    if locid >= 0 and locid < flex_worknum[i]:
      f = i
      break

  dim = flex_dim[f]
  face_offset = flex_faceadr[f]
  world_face_offset = worldid * nfaces
  vert_adr = flex_vertadr[f]

  if dim == 2:
    radius = flex_radius[f]
    elem_count = flex_elemnum[f]

    if locid < elem_count:
      # 2D element faces
      elemid = locid
      elem_adr = flex_elemdataadr[f]
      ebase = elem_adr + elemid * 3

      i0 = vert_adr + flex_elem[ebase + 0]
      i1 = vert_adr + flex_elem[ebase + 1]
      i2 = vert_adr + flex_elem[ebase + 2]

      v0 = flexvert_xpos_in[worldid, i0]
      v1 = flexvert_xpos_in[worldid, i1]
      v2 = flexvert_xpos_in[worldid, i2]

      # TODO: Use static conditional
      if smooth:
        n0 = flexvert_norm_in[worldid, i0]
        n1 = flexvert_norm_in[worldid, i1]
        n2 = flexvert_norm_in[worldid, i2]
      else:
        face_nrm = wp.cross(v1 - v0, v2 - v0)
        face_nrm = wp.normalize(face_nrm)
        n0 = face_nrm
        n1 = face_nrm
        n2 = face_nrm

      p0_pos = v0 + radius * n0
      p1_pos = v1 + radius * n1
      p2_pos = v2 + radius * n2

      p0_neg = v0 - radius * n0
      p1_neg = v1 - radius * n1
      p2_neg = v2 - radius * n2

      face_id0 = world_face_offset + face_offset + (2 * elemid)
      base0 = face_id0 * 3
      face_point_out[base0 + 0] = p0_pos
      face_point_out[base0 + 1] = p1_pos
      face_point_out[base0 + 2] = p2_pos

      face_id1 = world_face_offset + face_offset + (2 * elemid + 1)
      base1 = face_id1 * 3
      face_point_out[base1 + 0] = p0_neg
      face_point_out[base1 + 1] = p1_neg
      face_point_out[base1 + 2] = p2_neg
    else:
      # 2D shell faces
      shellid = locid - elem_count
      shell_adr = flex_shelldataadr[f]
      sbase = shell_adr + 2 * shellid

      i0 = vert_adr + flex_shell_in[sbase + 0]
      i1 = vert_adr + flex_shell_in[sbase + 1]

      v0 = flexvert_xpos_in[worldid, i0]
      v1 = flexvert_xpos_in[worldid, i1]

      n0 = flexvert_norm_in[worldid, i0]
      n1 = flexvert_norm_in[worldid, i1]

      # side faces are stored after the 2 * elem_count element faces of this flex
      shell_face_offset = face_offset + (2 * elem_count)

      face_id0 = world_face_offset + shell_face_offset + (2 * shellid)
      base0 = face_id0 * 3
      face_point_out[base0 + 0] = v0 + radius * n0
      face_point_out[base0 + 1] = v1 - radius * n1
      face_point_out[base0 + 2] = v1 + radius * n1

      face_id1 = world_face_offset + shell_face_offset + (2 * shellid + 1)
      base1 = face_id1 * 3
      face_point_out[base1 + 0] = v1 - radius * n1
      face_point_out[base1 + 1] = v0 + radius * n0
      face_point_out[base1 + 2] = v0 - radius * n0
  else:
    # 3D shell faces
    shellid = locid
    shell_adr = flex_shelldataadr[f]
    sbase = shell_adr + shellid * 3

    i0 = vert_adr + flex_shell_in[sbase + 0]
    i1 = vert_adr + flex_shell_in[sbase + 1]
    i2 = vert_adr + flex_shell_in[sbase + 2]

    v0 = flexvert_xpos_in[worldid, i0]
    v1 = flexvert_xpos_in[worldid, i1]
    v2 = flexvert_xpos_in[worldid, i2]

    face_id = world_face_offset + face_offset + shellid
    fbase = face_id * 3
    face_point_out[fbase + 0] = v0
    face_point_out[fbase + 1] = v1
    face_point_out[fbase + 2] = v2


def build_flex_bvh(
  mjm: mujoco.MjModel, mjd: mujoco.MjData, nworld: int, constructor: str = "sah", leaf_size: int = 2
) -> tuple[wp.Mesh, wp.array, wp.array, wp.array, list, int]:
  """Create a Warp mesh BVH from flex data.

  Args:
    mjm: Host MuJoCo model holding the flex topology.
    mjd: Host MuJoCo data holding the flex vertex positions.
    nworld: Number of worlds to replicate the flex geometry into.
    constructor: BVH construction method passed to `wp.Mesh`.
    leaf_size: BVH leaf size passed to `wp.Mesh`.

  Returns:
    A 6-tuple `(flex_mesh, face_point, group_root, flex_shell, flex_faceadr,
    nface)`: the grouped Warp mesh, its point buffer, the per-world BVH group
    roots, the shell data on device, the per-flex first-face offsets (a Python
    list) and the number of faces per world.

  Raises:
    ValueError: If the model contains a 1D flex.
  """
  if (mjm.flex_dim == 1).any():
    raise ValueError("1D Flex objects are not currently supported.")

  nflex = mjm.nflex
  nflexvert = mjm.nflexvert
  nflexelemdata = len(mjm.flex_elem)

  flex_elem = wp.array(mjm.flex_elem, dtype=int)
  flexvert_xpos = wp.array(np.tile(mjd.flexvert_xpos[np.newaxis, :, :], (nworld, 1, 1)), dtype=wp.vec3)

  # running face count: 2D flexes emit 2 faces per element and per shell
  # fragment, 3D flexes emit 1 face per shell fragment
  flex_faceadr = [0]
  for f in range(nflex):
    if mjm.flex_dim[f] == 2:
      flex_faceadr.append(flex_faceadr[-1] + 2 * mjm.flex_elemnum[f] + 2 * mjm.flex_shellnum[f])
    elif mjm.flex_dim[f] == 3:
      flex_faceadr.append(flex_faceadr[-1] + mjm.flex_shellnum[f])

  # the last entry is the total; the remaining entries are per-flex offsets
  nface = int(flex_faceadr[-1])
  flex_faceadr = flex_faceadr[:-1]

  face_point = wp.empty(nface * 3 * nworld, dtype=wp.vec3)
  face_index = wp.empty(nface * 3 * nworld, dtype=wp.int32)
  group = wp.empty(nface * nworld, dtype=int)
  flexvert_norm = wp.zeros((nworld, nflexvert), dtype=wp.vec3)
  flex_shell = wp.array(mjm.flex_shell, dtype=int)

  wp.launch(
    kernel=accumulate_flex_vertex_normals,
    dim=(nworld, nflexelemdata // 3),
    inputs=[flex_elem, flexvert_xpos],
    outputs=[flexvert_norm],
  )

  wp.launch(
    kernel=normalize_vertex_normals,
    dim=(nworld, nflexvert),
    inputs=[flexvert_norm],
  )

  for f in range(nflex):
    dim = mjm.flex_dim[f]
    elem_adr = mjm.flex_elemdataadr[f]
    nelem = mjm.flex_elemnum[f]
    shell_adr = mjm.flex_shelldataadr[f]
    nshell = mjm.flex_shellnum[f]
    vert_adr = mjm.flex_vertadr[f]

    if dim == 2:
      wp.launch(
        kernel=_build_flex_2d_elements,
        dim=(nworld, nelem),
        inputs=[
          flex_elem,
          flexvert_xpos,
          flexvert_norm,
          elem_adr,
          vert_adr,
          flex_faceadr[f],
          mjm.flex_radius[f],
          nface,
        ],
        outputs=[face_point, face_index, group],
      )

      # side faces follow the 2 * nelem element faces of this flex
      wp.launch(
        kernel=_build_flex_2d_sides,
        dim=(nworld, nshell),
        inputs=[
          flexvert_xpos,
          flexvert_norm,
          flex_shell,
          shell_adr,
          vert_adr,
          flex_faceadr[f] + 2 * nelem,
          mjm.flex_radius[f],
          nface,
        ],
        outputs=[face_point, face_index, group],
      )
    elif dim == 3:
      wp.launch(
        kernel=_build_flex_3d_shells,
        dim=(nworld, nshell),
        inputs=[
          flexvert_xpos,
          flex_shell,
          shell_adr,
          vert_adr,
          flex_faceadr[f],
          nface,
        ],
        outputs=[face_point, face_index, group],
      )

  flex_mesh = wp.Mesh(
    points=face_point,
    indices=face_index,
    groups=group,
    bvh_constructor=constructor,
    bvh_leaf_size=leaf_size,
  )

  group_root = wp.empty(nworld, dtype=int)
  wp.launch(
    kernel=compute_bvh_group_roots,
    dim=nworld,
    inputs=[flex_mesh.id],
    outputs=[group_root],
  )

  return (
    flex_mesh,
    face_point,
    group_root,
    flex_shell,
    flex_faceadr,
    nface,
  )


def refit_flex_bvh(m: Model, d: Data, rc: RenderContext):
  """Refit the flex BVH.

  Vertex normals are recomputed from the current flex vertex positions, the
  face points in `rc.flex_face_point` are rewritten, and `rc.flex_mesh` is
  refit.

  Args:
    m: The model (flex topology).
    d: The data (current flex vertex positions, number of worlds).
    rc: Render context holding the flex buffers and the flex mesh.
  """
  # normals are accumulated with +=, so the buffer must start at zero
  flexvert_norm = wp.zeros(d.flexvert_xpos.shape, dtype=wp.vec3)

  wp.launch(
    kernel=accumulate_flex_vertex_normals,
    dim=(d.nworld, m.nflexelemdata // 3),
    inputs=[
      m.flex_elem,
      d.flexvert_xpos,
    ],
    outputs=[flexvert_norm],
  )

  wp.launch(
    kernel=normalize_vertex_normals,
    dim=(d.nworld, m.nflexvert),
    inputs=[flexvert_norm],
  )

  wp.launch(
    kernel=_update_flex_face_points,
    dim=(d.nworld, rc.flex_nwork),
    inputs=[
      m.nflex,
      m.flex_dim,
      m.flex_vertadr,
      m.flex_elemnum,
      m.flex_elem,
      d.flexvert_xpos,
      rc.flex_shell,
      flexvert_norm,
      rc.flex_elemdataadr,
      rc.flex_shelldataadr,
      rc.flex_faceadr,
      rc.flex_radius,
      rc.flex_workadr,
      rc.flex_worknum,
      rc.flex_nface,
      rc.flex_render_smooth,
    ],
    outputs=[rc.flex_face_point],
  )

  rc.flex_mesh.refit()