# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any
import warp as wp
from mujoco_warp._src.collision_convex import convex_narrowphase
from mujoco_warp._src.collision_primitive import primitive_narrowphase
from mujoco_warp._src.collision_sdf import sdf_narrowphase
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
from mujoco_warp._src.math import upper_tri_index
from mujoco_warp._src.types import MJ_MAXVAL
from mujoco_warp._src.types import BroadphaseFilter
from mujoco_warp._src.types import BroadphaseType
from mujoco_warp._src.types import CollisionContext
from mujoco_warp._src.types import CollisionType
from mujoco_warp._src.types import Data
from mujoco_warp._src.types import DisableBit
from mujoco_warp._src.types import GeomType
from mujoco_warp._src.types import Model
from mujoco_warp._src.types import mat23
from mujoco_warp._src.types import mat63
from mujoco_warp._src.warp_util import cache_kernel
from mujoco_warp._src.warp_util import event_scope
# Disable backward (adjoint) code generation for the Warp kernels of this module.
wp.set_module_options({"enable_backward": False})

# Corresponding table to MuJoCo's mjCOLLISIONFUNC table in engine_collision_driver.c.
# Maps a (GeomType, GeomType) key to the narrowphase family that handles it: _narrowphase sends
# CONVEX keys to convex_narrowphase and PRIMITIVE keys to primitive_narrowphase. Type pairs that
# are absent from this table are dispatched by neither of those two routines.
MJ_COLLISION_TABLE = {
  (GeomType.PLANE, GeomType.SPHERE): CollisionType.PRIMITIVE,
  (GeomType.PLANE, GeomType.CAPSULE): CollisionType.PRIMITIVE,
  (GeomType.PLANE, GeomType.ELLIPSOID): CollisionType.PRIMITIVE,
  (GeomType.PLANE, GeomType.CYLINDER): CollisionType.PRIMITIVE,
  (GeomType.PLANE, GeomType.BOX): CollisionType.PRIMITIVE,
  (GeomType.PLANE, GeomType.MESH): CollisionType.PRIMITIVE,
  (GeomType.HFIELD, GeomType.SPHERE): CollisionType.CONVEX,
  (GeomType.HFIELD, GeomType.CAPSULE): CollisionType.CONVEX,
  (GeomType.HFIELD, GeomType.ELLIPSOID): CollisionType.CONVEX,
  (GeomType.HFIELD, GeomType.CYLINDER): CollisionType.CONVEX,
  (GeomType.HFIELD, GeomType.BOX): CollisionType.CONVEX,
  (GeomType.HFIELD, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.SPHERE, GeomType.SPHERE): CollisionType.PRIMITIVE,
  (GeomType.SPHERE, GeomType.CAPSULE): CollisionType.PRIMITIVE,
  (GeomType.SPHERE, GeomType.ELLIPSOID): CollisionType.CONVEX,
  (GeomType.SPHERE, GeomType.CYLINDER): CollisionType.PRIMITIVE,
  (GeomType.SPHERE, GeomType.BOX): CollisionType.PRIMITIVE,
  (GeomType.SPHERE, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.CAPSULE, GeomType.CAPSULE): CollisionType.PRIMITIVE,
  (GeomType.CAPSULE, GeomType.ELLIPSOID): CollisionType.CONVEX,
  (GeomType.CAPSULE, GeomType.CYLINDER): CollisionType.CONVEX,
  (GeomType.CAPSULE, GeomType.BOX): CollisionType.PRIMITIVE,
  (GeomType.CAPSULE, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.ELLIPSOID, GeomType.ELLIPSOID): CollisionType.CONVEX,
  (GeomType.ELLIPSOID, GeomType.CYLINDER): CollisionType.CONVEX,
  (GeomType.ELLIPSOID, GeomType.BOX): CollisionType.CONVEX,
  (GeomType.ELLIPSOID, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.CYLINDER, GeomType.CYLINDER): CollisionType.CONVEX,
  (GeomType.CYLINDER, GeomType.BOX): CollisionType.CONVEX,
  (GeomType.CYLINDER, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.BOX, GeomType.BOX): CollisionType.CONVEX,  # remapped to PRIMITIVE by _narrowphase when the NATIVECCD disable bit is set
  (GeomType.BOX, GeomType.MESH): CollisionType.CONVEX,
  (GeomType.MESH, GeomType.MESH): CollisionType.CONVEX,
}
def create_collision_context(naconmax: int) -> CollisionContext:
  """Allocate the buffers that carry broadphase candidate pairs to the narrowphase.

  Args:
    naconmax: length of every buffer, i.e. the maximum number of candidate pairs stored.

  Returns:
    A CollisionContext whose ``collision_pair``, ``collision_pairid`` and
    ``collision_worldid`` arrays are freshly allocated with ``wp.empty`` (uninitialized).
  """
  pair_buffer = wp.empty(naconmax, dtype=wp.vec2i)
  pairid_buffer = wp.empty(naconmax, dtype=wp.vec2i)
  worldid_buffer = wp.empty(naconmax, dtype=int)
  return CollisionContext(
    collision_pair=pair_buffer,
    collision_pairid=pairid_buffer,
    collision_worldid=worldid_buffer,
  )
@wp.kernel
def _zero_nacon_ncollision(
  # Data out:
  nacon_out: wp.array(dtype=int),
  ncollision_out: wp.array(dtype=int),
):
  """Reset the contact counter and the candidate-pair counter (collision() launches this with dim=1)."""
  nacon_out[0] = 0
  ncollision_out[0] = 0
@wp.func
def _plane_filter(
  size1: float, size2: float, margin1: float, margin2: float, xpos1: wp.vec3, xpos2: wp.vec3, xmat1: wp.mat33, xmat2: wp.mat33
) -> bool:
  """Reject a plane/geom pair whose geom center lies too far above the plane.

  A zero ``size`` (bounding radius) marks the plane. The other geom's center is measured
  along the plane's local z axis (column 2 of its rotation matrix). The pair is kept when
  that signed distance does not exceed the other geom's size plus the pair margin.
  Pairs with no plane are always kept.
  """
  if size1 != 0.0 and size2 != 0.0:
    # neither geom is a plane: nothing for this filter to decide
    return True

  if size1 == 0.0:
    # geom1 is the plane
    plane_normal = wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2])
    height = wp.dot(xpos2 - xpos1, plane_normal)
    other_size = size2
  else:
    # geom2 is the plane
    plane_normal = wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2])
    height = wp.dot(xpos1 - xpos2, plane_normal)
    other_size = size1

  # keep the same addition order as the margin convention elsewhere (left to right)
  if BLEEDING_EDGE_MUJOCO:
    return height <= other_size + margin1 + margin2
  return height <= other_size + wp.max(margin1, margin2)
@wp.func
def _sphere_filter(size1: float, size2: float, margin1: float, margin2: float, xpos1: wp.vec3, xpos2: wp.vec3) -> bool:
  """Keep the pair when the two bounding spheres, padded by the pair margin, touch or overlap.

  Compares squared center distance against the squared reach to avoid a square root.
  """
  offset = xpos2 - xpos1
  if BLEEDING_EDGE_MUJOCO:
    reach = size1 + size2 + margin1 + margin2
  else:
    reach = size1 + size2 + wp.max(margin1, margin2)
  return wp.dot(offset, offset) <= reach * reach
# TODO(team): improve performance by precomputing bounding box
@wp.func
def _aabb_filter(
  # In:
  center1: wp.vec3,
  center2: wp.vec3,
  size1: wp.vec3,
  size2: wp.vec3,
  margin1: float,
  margin2: float,
  xpos1: wp.vec3,
  xpos2: wp.vec3,
  xmat1: wp.mat33,
  xmat2: wp.mat33,
) -> bool:
  """Axis aligned boxes collision.

  references: see Ericson, Real-time Collision Detection section 4.2.
  filterBox: filter contact based on global AABBs.

  Each geom's box is given in its own frame by ``center*`` and half-extents ``size*``.
  The box is rotated and translated into the world by ``xmat*``/``xpos*``. Its 8 rotated
  corners then give the world-axis-aligned box that encloses it. The two enclosing boxes
  are compared on the x, y and z world axes. The pair margin is the sum of the margins when
  BLEEDING_EDGE_MUJOCO is set, otherwise their max.

  Returns:
    False if the enclosing boxes are separated by more than the margin along any world
    axis, True otherwise.
  """
  # box centers: geom frame -> world frame
  center1 = xmat1 @ center1 + xpos1
  center2 = xmat2 @ center2 + xpos2
  if BLEEDING_EDGE_MUJOCO:
    margin = margin1 + margin2
  else:
    margin = wp.max(margin1, margin2)
  # running min/max of each rotated box's corners, relative to its world-space center
  max_x1 = -MJ_MAXVAL
  max_y1 = -MJ_MAXVAL
  max_z1 = -MJ_MAXVAL
  min_x1 = MJ_MAXVAL
  min_y1 = MJ_MAXVAL
  min_z1 = MJ_MAXVAL
  max_x2 = -MJ_MAXVAL
  max_y2 = -MJ_MAXVAL
  max_z2 = -MJ_MAXVAL
  min_x2 = MJ_MAXVAL
  min_y2 = MJ_MAXVAL
  min_z2 = MJ_MAXVAL
  sign = wp.vec2(-1.0, 1.0)
  # visit the 8 corners (+-x, +-y, +-z) of both boxes in lockstep
  for i in range(2):
    for j in range(2):
      for k in range(2):
        corner1 = wp.vec3(sign[i] * size1[0], sign[j] * size1[1], sign[k] * size1[2])
        pos1 = xmat1 @ corner1
        corner2 = wp.vec3(sign[i] * size2[0], sign[j] * size2[1], sign[k] * size2[2])
        pos2 = xmat2 @ corner2
        if pos1[0] > max_x1:
          max_x1 = pos1[0]
        if pos1[1] > max_y1:
          max_y1 = pos1[1]
        if pos1[2] > max_z1:
          max_z1 = pos1[2]
        if pos1[0] < min_x1:
          min_x1 = pos1[0]
        if pos1[1] < min_y1:
          min_y1 = pos1[1]
        if pos1[2] < min_z1:
          min_z1 = pos1[2]
        if pos2[0] > max_x2:
          max_x2 = pos2[0]
        if pos2[1] > max_y2:
          max_y2 = pos2[1]
        if pos2[2] > max_z2:
          max_z2 = pos2[2]
        if pos2[0] < min_x2:
          min_x2 = pos2[0]
        if pos2[1] < min_y2:
          min_y2 = pos2[1]
        if pos2[2] < min_z2:
          min_z2 = pos2[2]
  # box 1 entirely below box 2 (beyond the margin) on some world axis
  if center1[0] + max_x1 + margin < center2[0] + min_x2:
    return False
  if center1[1] + max_y1 + margin < center2[1] + min_y2:
    return False
  if center1[2] + max_z1 + margin < center2[2] + min_z2:
    return False
  # box 2 entirely below box 1 (beyond the margin) on some world axis
  if center2[0] + max_x2 + margin < center1[0] + min_x1:
    return False
  if center2[1] + max_y2 + margin < center1[1] + min_y1:
    return False
  if center2[2] + max_z2 + margin < center1[2] + min_z1:
    return False
  return True
# TODO(team): improve performance by precomputing bounding box
@wp.func
def _obb_filter(
  # In:
  center1: wp.vec3,
  center2: wp.vec3,
  size1: wp.vec3,
  size2: wp.vec3,
  margin1: float,
  margin2: float,
  xpos1: wp.vec3,
  xpos2: wp.vec3,
  xmat1: wp.mat33,
  xmat2: wp.mat33,
) -> bool:
  """Oriented bounding boxes collision (see Gottschalk et al.), see mj_collideOBB.

  Separating-axis test between the two geoms' boxes. Each box has center ``center*`` and
  half-extents ``size*`` in its geom frame, placed in the world by ``xpos*``/``xmat*``. Only
  the six box face normals (the columns of xmat1 and xmat2) are tried as axes. Edge-edge
  cross-product axes are not tested, so the filter is conservative: it may keep some
  separated pairs but does not reject a pair that overlaps within the margin.

  Returns:
    False if, on some tested axis, the centers' projections are farther apart than the two
    projected box radii plus the margin. True otherwise.
  """
  if BLEEDING_EDGE_MUJOCO:
    margin = margin1 + margin2
  else:
    margin = wp.max(margin1, margin2)
  xcenter = mat23()
  normal = mat63()
  proj = wp.vec2()
  radius = wp.vec2()
  # compute box centers in global coordinates
  xcenter[0] = xmat1 @ center1 + xpos1
  xcenter[1] = xmat2 @ center2 + xpos2
  # compute normals in global coordinates
  # (rows 0-2: geom1 local axes, rows 3-5: geom2 local axes)
  normal[0] = wp.vec3(xmat1[0, 0], xmat1[1, 0], xmat1[2, 0])
  normal[1] = wp.vec3(xmat1[0, 1], xmat1[1, 1], xmat1[2, 1])
  normal[2] = wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2])
  normal[3] = wp.vec3(xmat2[0, 0], xmat2[1, 0], xmat2[2, 0])
  normal[4] = wp.vec3(xmat2[0, 1], xmat2[1, 1], xmat2[2, 1])
  normal[5] = wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2])
  # check intersections: axis normal[3 * j + k] for every face normal of both boxes
  for j in range(2):
    for k in range(3):
      for i in range(2):
        # projection of box i's center and its projected radius on the current axis
        proj[i] = wp.dot(xcenter[i], normal[3 * j + k])
        if i == 0:
          size = size1
        else:
          size = size2
        # fmt: off
        radius[i] = (
            wp.abs(size[0] * wp.dot(normal[3 * i + 0], normal[3 * j + k]))
          + wp.abs(size[1] * wp.dot(normal[3 * i + 1], normal[3 * j + k]))
          + wp.abs(size[2] * wp.dot(normal[3 * i + 2], normal[3 * j + k]))
        )
        # fmt: on
      # separated on this axis (beyond the margin): reject
      if radius[0] + radius[1] + margin < wp.abs(proj[1] - proj[0]):
        return False
  return True
def _broadphase_filter(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
  """Build the Warp function that decides whether a geom pair survives the broadphase.

  Args:
    opt_broadphase_filter: bitmask of BroadphaseFilter flags. Each set bit compiles in the
      corresponding test via wp.static; unset tests are not generated.
    ngeom_aabb: leading (batch) dimension of geom_aabb. When > 1, the row is
      worldid % ngeom_aabb; otherwise row 0 is shared by all worlds.
    ngeom_rbound: same as ngeom_aabb, for geom_rbound.
    ngeom_margin: same as ngeom_aabb, for geom_margin.

  Returns:
    A ``wp.func`` that returns True if the pair (geom1, geom2) in world ``worldid`` should be
    kept for the narrowphase.
  """
  @wp.func
  def func(
    # Model:
    geom_aabb: wp.array3d(dtype=wp.vec3),
    geom_rbound: wp.array2d(dtype=float),
    geom_margin: wp.array2d(dtype=float),
    # Data in:
    geom_xpos_in: wp.array2d(dtype=wp.vec3),
    geom_xmat_in: wp.array2d(dtype=wp.mat33),
    # In:
    geom1: int,
    geom2: int,
    worldid: int,
  ) -> bool:
    # 1: plane
    # 2: sphere
    # 4: aabb
    # 8: obb
    # geom_aabb[..., 0] is the box center and geom_aabb[..., 1] its half-extents (geom frame)
    aabb_id = worldid % ngeom_aabb if wp.static(ngeom_aabb > 1) else 0
    center1, center2 = geom_aabb[aabb_id, geom1, 0], geom_aabb[aabb_id, geom2, 0]
    size1, size2 = geom_aabb[aabb_id, geom1, 1], geom_aabb[aabb_id, geom2, 1]
    rbound_id = worldid % ngeom_rbound if wp.static(ngeom_rbound > 1) else 0
    rbound1, rbound2 = geom_rbound[rbound_id, geom1], geom_rbound[rbound_id, geom2]
    margin_id = worldid % ngeom_margin if wp.static(ngeom_margin > 1) else 0
    margin1, margin2 = geom_margin[margin_id, geom1], geom_margin[margin_id, geom2]
    xpos1, xpos2 = geom_xpos_in[worldid, geom1], geom_xpos_in[worldid, geom2]
    xmat1, xmat2 = geom_xmat_in[worldid, geom1], geom_xmat_in[worldid, geom2]
    if rbound1 == 0.0 or rbound2 == 0.0:
      # a zero bounding radius marks a plane: only the plane test applies, and the pair is
      # kept when that test is disabled
      if wp.static(opt_broadphase_filter & BroadphaseFilter.PLANE):
        return _plane_filter(rbound1, rbound2, margin1, margin2, xpos1, xpos2, xmat1, xmat2)
    else:
      # every enabled test must pass; the first failing one rejects the pair
      if wp.static(opt_broadphase_filter & BroadphaseFilter.SPHERE):
        if not _sphere_filter(rbound1, rbound2, margin1, margin2, xpos1, xpos2):
          return False
      if wp.static(opt_broadphase_filter & BroadphaseFilter.AABB):
        if not _aabb_filter(center1, center2, size1, size2, margin1, margin2, xpos1, xpos2, xmat1, xmat2):
          return False
      if wp.static(opt_broadphase_filter & BroadphaseFilter.OBB):
        if not _obb_filter(center1, center2, size1, size2, margin1, margin2, xpos1, xpos2, xmat1, xmat2):
          return False
    return True

  return func
@wp.func
def _add_geom_pair(
  # Model:
  geom_type: wp.array(dtype=int),
  nxn_pairid: wp.array(dtype=wp.vec2i),
  # Data in:
  naconmax_in: int,
  # In:
  geom1: int,
  geom2: int,
  worldid: int,
  nxnid: int,
  # Data out:
  ncollision_out: wp.array(dtype=int),
  # Out:
  collision_pair_out: wp.array(dtype=wp.vec2i),
  collision_pairid_out: wp.array(dtype=wp.vec2i),
  collision_worldid_out: wp.array(dtype=int),
):
  """Append one candidate pair to the collision buffers.

  A slot is reserved by atomically incrementing ``ncollision_out[0]``. The counter is bumped
  even when the slot is at or past ``naconmax_in``; in that case nothing is written, so the
  counter can exceed capacity and reveal the overflow. The geom with the larger geom_type is
  stored second. The pair's ``nxn_pairid[nxnid]`` entry and ``worldid`` are stored in the
  same slot.
  """
  slot = wp.atomic_add(ncollision_out, 0, 1)
  if slot < naconmax_in:
    ordered = wp.vec2i(geom1, geom2)
    if geom_type[geom1] > geom_type[geom2]:
      ordered = wp.vec2i(geom2, geom1)
    collision_pair_out[slot] = ordered
    collision_pairid_out[slot] = nxn_pairid[nxnid]
    collision_worldid_out[slot] = worldid
@wp.func
def _binary_search(values: wp.array(dtype=Any), value: Any, lower: int, upper: int) -> int:
  """Upper-bound search on values[lower:upper].

  Returns the first index in [lower, upper) whose element is strictly greater than ``value``,
  or ``upper`` when there is none. Meaningful only if that slice is sorted non-decreasingly.
  """
  lo = lower
  hi = upper
  while lo < hi:
    mid = (lo + hi) >> 1
    if values[mid] > value:
      # mid is a candidate; search the left half (inclusive of mid)
      hi = mid
    else:
      # mid is not a candidate; search strictly to its right
      lo = mid + 1
  return hi
def _sap_project(opt_broadphase: int):
  """Build the kernel that projects each geom's padded bounding sphere onto the sweep axis.

  Args:
    opt_broadphase: BroadphaseType. For SAP_SEGMENTED the kernel also writes the per-world
      segment offsets that sap_broadphase passes to wp.utils.segmented_sort_pairs.

  Returns:
    A kernel that sap_broadphase launches over (nworld, ngeom).
  """
  @wp.kernel(module="unique", enable_backward=False)
  def sap_project(
    # Model:
    ngeom: int,
    geom_rbound: wp.array2d(dtype=float),
    geom_margin: wp.array2d(dtype=float),
    # Data in:
    geom_xpos_in: wp.array2d(dtype=wp.vec3),
    nworld_in: int,
    # In:
    direction_in: wp.vec3,
    # Out:
    projection_lower_out: wp.array2d(dtype=float),
    projection_upper_out: wp.array2d(dtype=float),
    sort_index_out: wp.array2d(dtype=int),
    segmented_index_out: wp.array(dtype=int),
  ):
    worldid, geomid = wp.tid()
    xpos = geom_xpos_in[worldid, geomid]
    # model parameters may be batched per world; wrap worldid onto the available rows
    rbound = geom_rbound[worldid % geom_rbound.shape[0], geomid]
    if rbound == 0.0:
      # geom is a plane
      # (an MJ_MAXVAL radius makes its interval span the whole axis)
      rbound = MJ_MAXVAL
    # interval half-width: bounding radius padded by the geom margin
    radius = rbound + geom_margin[worldid % geom_margin.shape[0], geomid]
    center = wp.dot(direction_in, xpos)
    # identity permutation; the subsequent sort reorders it alongside the lower bounds
    sort_index_out[worldid, geomid] = geomid
    if not wp.isnan(center):
      projection_lower_out[worldid, geomid] = center - radius
      projection_upper_out[worldid, geomid] = center + radius
    else:
      # NaN position: park the interval at MJ_MAXVAL so it sorts after (presumably) all finite ones
      projection_lower_out[worldid, geomid] = MJ_MAXVAL
      projection_upper_out[worldid, geomid] = MJ_MAXVAL
    if wp.static(opt_broadphase == BroadphaseType.SAP_SEGMENTED):
      # segment start offset per world, plus the end offset of the last segment
      if geomid == 0:
        segmented_index_out[worldid] = worldid * ngeom
        if worldid == nworld_in - 1:
          segmented_index_out[nworld_in] = nworld_in * ngeom

  return sap_project
@wp.kernel
def _sap_range(
  # Model:
  ngeom: int,
  # In:
  projection_lower_in: wp.array2d(dtype=float),
  projection_upper_in: wp.array2d(dtype=float),
  sort_index_in: wp.array2d(dtype=int),
  # Out:
  range_out: wp.array2d(dtype=int),
):
  """For each sorted position, count the later positions the sweep must test it against.

  Here ``geomid`` is a position in the sorted order, not a geom id. projection_lower_in[worldid]
  is the sorted lower bounds; sort_index_in maps positions to geom ids; projection_upper_in is
  indexed by geom id (unsorted).
  """
  worldid, geomid = wp.tid()
  # current bounding geom
  idx = sort_index_in[worldid, geomid]
  upper = projection_upper_in[worldid, idx]
  # first later position whose lower bound exceeds this geom's upper bound
  limit = _binary_search(projection_lower_in[worldid], upper, geomid + 1, ngeom)
  limit = wp.min(ngeom - 1, limit)
  # range of geoms for the sweep and prune process
  # NOTE(review): the sweep visits positions geomid+1 .. geomid+range. When the search stops
  # before the end, that includes position `limit` itself, whose interval does not overlap.
  # This is conservative; the broadphase filter decides on that extra pair.
  range_out[worldid, geomid] = limit - geomid
@cache_kernel
def _sap_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
  """Build the sweep kernel that turns SAP candidate ranges into filtered geom pairs.

  cumulative_sum_in is the inclusive prefix sum of the flattened (world, sorted position)
  ranges from _sap_range. Each unit of that sum is one work item, i.e. one candidate pair
  (sorted position i, later position j). A thread starts at its thread id and advances by
  ``nsweep_in`` work items until all items are processed. The arguments mirror
  _broadphase_filter: the filter bitmask and the batch sizes of geom_aabb, geom_rbound and
  geom_margin.
  """
  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # Model:
    ngeom: int,
    geom_type: wp.array(dtype=int),
    geom_aabb: wp.array3d(dtype=wp.vec3),
    geom_rbound: wp.array2d(dtype=float),
    geom_margin: wp.array2d(dtype=float),
    nxn_pairid: wp.array(dtype=wp.vec2i),
    # Data in:
    geom_xpos_in: wp.array2d(dtype=wp.vec3),
    geom_xmat_in: wp.array2d(dtype=wp.mat33),
    nworld_in: int,
    naconmax_in: int,
    # In:
    sort_index_in: wp.array2d(dtype=int),
    cumulative_sum_in: wp.array(dtype=int),
    nsweep_in: int,
    # Data out:
    ncollision_out: wp.array(dtype=int),
    # Out:
    collision_pair_out: wp.array(dtype=wp.vec2i),
    collision_pairid_out: wp.array(dtype=wp.vec2i),
    collision_worldid_out: wp.array(dtype=int),
  ):
    # worldgeomid is the current work-item index
    worldgeomid = wp.tid()
    nworldgeom = nworld_in * ngeom
    # inclusive prefix sum: the last entry is the total number of candidate pairs
    nworkpackages = cumulative_sum_in[nworldgeom - 1]
    while worldgeomid < nworkpackages:
      # binary search to find current and next geom pair indices
      # (i: flattened world/position owning this work item; j: its partner position)
      i = _binary_search(cumulative_sum_in, worldgeomid, 0, nworldgeom)
      j = i + worldgeomid + 1
      if i > 0:
        j -= cumulative_sum_in[i - 1]
      worldid = i // ngeom
      i = i % ngeom
      j = j % ngeom
      # get geom indices and swap if necessary
      geom1 = sort_index_in[worldid, i]
      geom2 = sort_index_in[worldid, j]
      # find linear index of (geom1, geom2) in upper triangular nxn_pairid
      if geom2 < geom1:
        idx = upper_tri_index(ngeom, geom2, geom1)
      else:
        idx = upper_tri_index(ngeom, geom1, geom2)
      # advance before the `continue` below so the loop always makes progress
      worldgeomid += nsweep_in
      pairid = nxn_pairid[idx]
      # NOTE(review): the nxn_pairid encoding is defined where the model is built. This
      # appears to drop pairs whose first component is below -1 unless the second
      # component is non-negative.
      if pairid[0] < -1 and pairid[1] < 0:
        continue
      # keep pairs accepted by the enabled filters, and always those with pairid[1] >= 0
      if (
        wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin))(
          geom_aabb, geom_rbound, geom_margin, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
        )
        or pairid[1] >= 0
      ):
        _add_geom_pair(
          geom_type,
          nxn_pairid,
          naconmax_in,
          geom1,
          geom2,
          worldid,
          idx,
          ncollision_out,
          collision_pair_out,
          collision_pairid_out,
          collision_worldid_out,
        )

  return kernel
def _segmented_sort(tile_size: int):
  """Build a tiled kernel that sorts each world's row of (key, geom index) pairs.

  Args:
    tile_size: number of entries per world row (sap_broadphase passes ngeom).

  Returns:
    A kernel, launched with one tile per world, that sorts keys ascending and permutes
    the matching indices the same way, writing the results to the output arrays.
  """
  @wp.kernel(module="unique")
  def segmented_sort(
    # In:
    projection_lower_in: wp.array2d(dtype=float),
    sort_index_in: wp.array2d(dtype=int),
    # Out:
    projection_lower_out: wp.array2d(dtype=float),
    sort_index_out: wp.array2d(dtype=int),
  ):
    world = wp.tid()

    # stage this world's keys and geom indices in shared-memory tiles
    key_tile = wp.tile_load(projection_lower_in[world], shape=tile_size, storage="shared")
    index_tile = wp.tile_load(sort_index_in[world], shape=tile_size, storage="shared")

    # sort the keys in place, carrying the indices along
    wp.tile_sort(key_tile, index_tile)

    # write the sorted tiles back out
    wp.tile_store(projection_lower_out[world], key_tile)
    wp.tile_store(sort_index_out[world], index_tile)

  return segmented_sort
@event_scope
def sap_broadphase(m: Model, d: Data, ctx: CollisionContext):
  """Runs broadphase collision detection using a sweep-and-prune (SAP) algorithm.

  This method is more efficient than the N-squared approach for large numbers of
  objects. It works by projecting the bounding spheres of all geoms, padded by their
  margins, onto a single axis and sorting them. It then sweeps along the axis, only
  checking for overlaps between geoms whose projections are close to each other.

  Each candidate pair found by the sweep is passed through the broadphase filters
  selected by `m.opt.broadphase_filter`. Pairs whose `m.nxn_pairid` entry has a
  non-negative second component are kept regardless. Surviving pairs are appended to
  the collision arrays in `ctx` (`collision_pair`, `collision_pairid`,
  `collision_worldid`) and counted in `d.ncollision` for the narrowphase stage.

  Two sorting strategies are supported, controlled by `m.opt.broadphase`:
    - `SAP_TILE`: Uses a tile-based sort.
    - `SAP_SEGMENTED`: Uses a segmented sort.

  Args:
    m: The model.
    d: The data. Geom poses are read from it and `d.ncollision` is updated.
    ctx: Collision context that receives the candidate pairs.
  """
  nworldgeom = d.nworld * m.ngeom

  # TODO(team): direction
  # random fixed direction
  direction = wp.vec3(0.5935, 0.7790, 0.1235)
  direction = wp.normalize(direction)

  # projection_lower and sort_index hold 2 * nworld * ngeom entries. Only rows
  # 0..nworld-1 of the reshape((-1, ngeom)) view carry data. NOTE(review): the extra
  # half is presumably sort scratch space for wp.utils.segmented_sort_pairs; confirm.
  projection_lower = wp.empty((d.nworld, m.ngeom, 2), dtype=float)
  projection_upper = wp.empty((d.nworld, m.ngeom), dtype=float)
  sort_index = wp.empty((d.nworld, m.ngeom, 2), dtype=int)
  range_ = wp.empty((d.nworld, m.ngeom), dtype=int)
  cumulative_sum = wp.empty((d.nworld, m.ngeom), dtype=int)
  # segment offsets are only written (and needed) for SAP_SEGMENTED
  segmented_index = wp.empty(d.nworld + 1 if m.opt.broadphase == BroadphaseType.SAP_SEGMENTED else 0, dtype=int)

  # project every geom's padded bounding sphere onto the sweep direction
  wp.launch(
    kernel=_sap_project(m.opt.broadphase),
    dim=(d.nworld, m.ngeom),
    inputs=[m.ngeom, m.geom_rbound, m.geom_margin, d.geom_xpos, d.nworld, direction],
    outputs=[
      projection_lower.reshape((-1, m.ngeom)),
      projection_upper,
      sort_index.reshape((-1, m.ngeom)),
      segmented_index,
    ],
  )

  # sort each world's intervals by lower bound, carrying the geom ids along
  if m.opt.broadphase == BroadphaseType.SAP_TILE:
    wp.launch_tiled(
      kernel=_segmented_sort(m.ngeom),
      dim=d.nworld,
      inputs=[projection_lower.reshape((-1, m.ngeom)), sort_index.reshape((-1, m.ngeom))],
      outputs=[projection_lower.reshape((-1, m.ngeom)), sort_index.reshape((-1, m.ngeom))],
      block_dim=m.block_dim.segmented_sort,
    )
  else:
    wp.utils.segmented_sort_pairs(
      projection_lower.reshape((-1, m.ngeom)), sort_index.reshape((-1, m.ngeom)), nworldgeom, segmented_index
    )

  # for every sorted position, the number of later positions to sweep against
  wp.launch(
    kernel=_sap_range,
    dim=(d.nworld, m.ngeom),
    inputs=[m.ngeom, projection_lower.reshape((-1, m.ngeom)), projection_upper, sort_index.reshape((-1, m.ngeom))],
    outputs=[range_],
  )

  # scan is used for load balancing among the threads (inclusive prefix sum)
  wp.utils.array_scan(range_.reshape(-1), cumulative_sum.reshape(-1), True)

  # estimate number of overlap checks
  # assumes each geom has 5 other geoms (batched over all worlds)
  nsweep = 5 * nworldgeom

  wp.launch(
    kernel=_sap_broadphase(m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0]),
    dim=nsweep,
    inputs=[
      m.ngeom,
      m.geom_type,
      m.geom_aabb,
      m.geom_rbound,
      m.geom_margin,
      m.nxn_pairid,
      d.geom_xpos,
      d.geom_xmat,
      d.nworld,
      d.naconmax,
      sort_index.reshape((-1, m.ngeom)),
      cumulative_sum.reshape(-1),
      nsweep,
    ],
    outputs=[d.ncollision, ctx.collision_pair, ctx.collision_pairid, ctx.collision_worldid],
  )
@cache_kernel
def _nxn_broadphase(opt_broadphase_filter: int, ngeom_aabb: int, ngeom_rbound: int, ngeom_margin: int):
  """Build the brute-force broadphase kernel.

  nxn_broadphase launches the returned kernel over (nworld, number of pre-filtered pairs).
  Each thread tests one geom pair in one world. The arguments mirror _broadphase_filter:
  the filter bitmask and the batch sizes of geom_aabb, geom_rbound and geom_margin.
  """
  @wp.kernel(module="unique", enable_backward=False)
  def kernel(
    # Model:
    geom_type: wp.array(dtype=int),
    geom_aabb: wp.array3d(dtype=wp.vec3),
    geom_rbound: wp.array2d(dtype=float),
    geom_margin: wp.array2d(dtype=float),
    nxn_geom_pair: wp.array(dtype=wp.vec2i),
    nxn_pairid: wp.array(dtype=wp.vec2i),
    # Data in:
    geom_xpos_in: wp.array2d(dtype=wp.vec3),
    geom_xmat_in: wp.array2d(dtype=wp.mat33),
    naconmax_in: int,
    # Data out:
    ncollision_out: wp.array(dtype=int),
    # Out:
    collision_pair_out: wp.array(dtype=wp.vec2i),
    collision_pairid_out: wp.array(dtype=wp.vec2i),
    collision_worldid_out: wp.array(dtype=int),
  ):
    # elementid indexes both the pair list and its matching nxn_pairid entry
    worldid, elementid = wp.tid()
    geom = nxn_geom_pair[elementid]
    geom1 = geom[0]
    geom2 = geom[1]
    # keep pairs accepted by the enabled filters, and always those whose pairid's second
    # component is non-negative
    if (
      wp.static(_broadphase_filter(opt_broadphase_filter, ngeom_aabb, ngeom_rbound, ngeom_margin))(
        geom_aabb, geom_rbound, geom_margin, geom_xpos_in, geom_xmat_in, geom1, geom2, worldid
      )
      or nxn_pairid[elementid][1] >= 0
    ):
      _add_geom_pair(
        geom_type,
        nxn_pairid,
        naconmax_in,
        geom1,
        geom2,
        worldid,
        elementid,
        ncollision_out,
        collision_pair_out,
        collision_pairid_out,
        collision_worldid_out,
      )

  return kernel
@event_scope
def nxn_broadphase(m: Model, d: Data, ctx: CollisionContext):
  """Runs broadphase collision detection using a brute-force N-squared approach.

  This function iterates through a pre-filtered list of geometry pairs
  (`m.nxn_geom_pair_filtered`) in every world. It applies the broadphase filters selected
  by `m.opt.broadphase_filter` to identify potential collisions.

  Each pair that passes, or whose `m.nxn_pairid_filtered` entry has a non-negative second
  component, is appended to the collision arrays in `ctx` (`ctx.collision_pair`,
  `ctx.collision_pairid`, `ctx.collision_worldid`) and counted in `d.ncollision`. These are
  then consumed by the narrowphase.

  The initial list of pairs is filtered at model creation time to exclude pairs based on
  `contype`/`conaffinity`, parent-child relationships, and explicit `<exclude>` tags.

  Args:
    m: The model.
    d: The data. Geom poses are read from it and `d.ncollision` is updated.
    ctx: Collision context that receives the candidate pairs.
  """
  wp.launch(
    _nxn_broadphase(m.opt.broadphase_filter, m.geom_aabb.shape[0], m.geom_rbound.shape[0], m.geom_margin.shape[0]),
    dim=(d.nworld, m.nxn_geom_pair_filtered.shape[0]),
    inputs=[
      m.geom_type,
      m.geom_aabb,
      m.geom_rbound,
      m.geom_margin,
      m.nxn_geom_pair_filtered,
      m.nxn_pairid_filtered,
      d.geom_xpos,
      d.geom_xmat,
      d.naconmax,
    ],
    outputs=[
      d.ncollision,
      ctx.collision_pair,
      ctx.collision_pairid,
      ctx.collision_worldid,
    ],
  )
def _narrowphase(m: Model, d: Data, ctx: CollisionContext):
  """Dispatch the candidate pairs in `ctx` to the narrowphase routines.

  Geom-type pairs are split into convex and primitive groups according to
  MJ_COLLISION_TABLE. When the NATIVECCD disable bit is set, box-box is routed to the
  primitive group instead of the convex one. SDF geoms, if the model has any, are handled
  by sdf_narrowphase.
  """
  # Work on a copy. Remapping box-box must not leak into the module-level table; otherwise
  # every later call, including ones for models with native CCD enabled, would keep routing
  # box-box to the primitive collider.
  collision_table = dict(MJ_COLLISION_TABLE)
  if m.opt.disableflags & DisableBit.NATIVECCD:
    collision_table[(GeomType.BOX, GeomType.BOX)] = CollisionType.PRIMITIVE

  convex_pairs = [pair for pair, ctype in collision_table.items() if ctype == CollisionType.CONVEX]
  primitive_pairs = [pair for pair, ctype in collision_table.items() if ctype == CollisionType.PRIMITIVE]

  # TODO(team): we should reject far-away contacts in the narrowphase instead of constraint
  # partitioning because we can move some pressure of the atomics
  convex_narrowphase(m, d, ctx, convex_pairs)
  primitive_narrowphase(m, d, ctx, primitive_pairs)

  if m.has_sdf_geom:
    sdf_narrowphase(m, d, ctx)
@event_scope
def collision(m: Model, d: Data):
  """Runs the full collision detection pipeline.

  This function orchestrates the broadphase and narrowphase collision detection stages. It
  first identifies potential collision pairs using a broadphase algorithm (either N-squared
  or Sweep-and-Prune, based on `m.opt.broadphase`). Then, for each potential pair, it
  performs narrowphase collision detection to compute detailed contact information like
  distance, position, and frame.

  The results are used to populate the `d.contact` array, and the total number of contacts
  is stored in `d.nacon`. If `d.nacon` is larger than `d.naconmax` then an overflow has
  occurred and the remaining contacts will be skipped. If this happens, raise the `nconmax`
  parameter in `io.make_data` or `io.put_data`.

  This function does nothing except zero the contact and candidate-pair counters
  (`d.nacon`, `d.ncollision`) in two cases: collision detection is disabled via the
  CONSTRAINT or CONTACT bits of `m.opt.disableflags`, or `d.naconmax` is 0.

  Args:
    m: The model.
    d: The data to read poses from and write contacts to.
  """
  if d.naconmax == 0 or m.opt.disableflags & (DisableBit.CONSTRAINT | DisableBit.CONTACT):
    # No collision pass this step. Clear both counters so neither reports a stale value from
    # a previous call; the normal path below zeroes both as well.
    d.nacon.zero_()
    d.ncollision.zero_()
    return

  # scratch buffers for broadphase -> narrowphase hand-off, sized by contact capacity
  ctx = create_collision_context(d.naconmax)

  # zero counters
  wp.launch(_zero_nacon_ncollision, dim=1, outputs=[d.nacon, d.ncollision])

  if m.opt.broadphase == BroadphaseType.NXN:
    nxn_broadphase(m, d, ctx)
  else:
    sap_broadphase(m, d, ctx)

  _narrowphase(m, d, ctx)