# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import importlib.metadata
import warnings
from typing import Any, Optional, Sequence
import mujoco
import numpy as np
import packaging.version
import warp as wp
from mujoco_warp._src import bvh
from mujoco_warp._src import render_util
from mujoco_warp._src import smooth
from mujoco_warp._src import types
from mujoco_warp._src import warp_util
def _is_mujoco_dev() -> bool:
  """Reports whether the installed MuJoCo release is newer than 3.4.0.

  The version comes from ``mujoco.__version__`` when it is set and non-empty, and from the
  installed package metadata otherwise. Any ``-...`` suffix and any ``.dev...`` suffix are
  stripped before the comparison.

  Returns:
    True if the stripped version parses as greater than 3.4.0.
  """
  raw_version = getattr(mujoco, "__version__", None) or importlib.metadata.version("mujoco")
  release = raw_version.split("-")[0].split(".dev")[0]
  return packaging.version.parse(release) > packaging.version.parse("3.4.0")


# evaluated once at import time
BLEEDING_EDGE_MUJOCO = _is_mujoco_dev()
def _create_array(data: Any, spec: wp.array, sizes: dict[str, int]) -> wp.array | None:
  """Creates a warp array and populates it with data.

  The array shape is determined by a field spec referencing MjModel / MjData array sizes:
  string dimensions in ``spec.shape`` are looked up in ``sizes``; integer dimensions are used
  as-is. A spec shape of ``(0,)`` means "no fixed shape".

  Args:
    data: Host values to upload, or None to allocate a zero-filled array.
    spec: Array type annotation carrying ``shape`` and ``dtype``.
    sizes: Mapping from symbolic dimension names to concrete sizes.

  Returns:
    The device array, or None when there is neither data nor a fixed shape.
  """
  if spec.shape == (0,):
    shape = None
  else:
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in spec.shape)

  if data is not None:
    arr = wp.array(np.array(data), dtype=spec.dtype, shape=shape)
  elif shape is not None:
    arr = wp.zeros(shape, dtype=spec.dtype)
  else:
    return None  # nothing to do

  if spec.shape[0] == "*":
    # private attribute that lets JAX tell which fields are batched
    arr._is_batched = True
    # stride 0 on the leading axis is expected legacy behavior (but is deprecated)
    arr.strides = (0,) + arr.strides[1:]
  return arr
def is_sparse(mjm: mujoco.MjModel) -> bool:
  """Decides whether MJWarp uses the sparse matrix path for a model.

  With ``mjJAC_AUTO`` the decision is made here: sparse exactly when nv > 32. For an explicit
  dense/sparse jacobian setting, the answer is delegated to ``mujoco.mj_isSparse``.
  """
  if mjm.opt.jacobian != mujoco.mjtJacobian.mjJAC_AUTO:
    return bool(mujoco.mj_isSparse(mjm))
  return bool(mjm.nv > 32)
[docs]
def put_model(mjm: mujoco.MjModel) -> types.Model:
  """Creates a model on device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).

  Returns:
    The model containing kinematic and dynamic information (device).

  Raises:
    NotImplementedError: If mjm uses an actuator/equality/geom/sensor/wrap type, integrator,
      cone, solver or option flag MJWarp does not support, or uses flex collisions, the noslip
      solver, fluid forces with an implicit integrator, or body/actuator/sensor plugins.
    ValueError: If a dense jacobian is requested for nv > 60.
  """
  # check for compatible cuda toolkit and driver versions
  warp_util.check_toolkit_driver()

  # model: check supported features in array types
  for field, field_type, mj_type in (
    (mjm.actuator_trntype, types.TrnType, mujoco.mjtTrn),
    (mjm.actuator_dyntype, types.DynType, mujoco.mjtDyn),
    (mjm.actuator_gaintype, types.GainType, mujoco.mjtGain),
    (mjm.actuator_biastype, types.BiasType, mujoco.mjtBias),
    (mjm.eq_type, types.EqType, mujoco.mjtEq),
    (mjm.geom_type, types.GeomType, mujoco.mjtGeom),
    (mjm.sensor_type, types.SensorType, mujoco.mjtSensor),
    (mjm.wrap_type, types.WrapType, mujoco.mjtWrap),
  ):
    missing = ~np.isin(field, field_type)
    if missing.any():
      names = [mj_type(v).name for v in field[missing]]
      raise NotImplementedError(f"{names} not supported.")

  # opt: check supported features in scalar types
  for field, field_type, mj_type in (
    (mjm.opt.integrator, types.IntegratorType, mujoco.mjtIntegrator),
    (mjm.opt.cone, types.ConeType, mujoco.mjtCone),
    (mjm.opt.solver, types.SolverType, mujoco.mjtSolver),
  ):
    if field not in set(field_type):
      raise NotImplementedError(f"{mj_type(field).name} is unsupported.")

  # opt: check supported features in scalar flag types
  for field, field_type, mj_type in (
    (mjm.opt.disableflags, types.DisableBit, mujoco.mjtDisableBit),
    (mjm.opt.enableflags, types.EnableBit, mujoco.mjtEnableBit),
  ):
    unsupported = field & ~np.bitwise_or.reduce(field_type)
    if unsupported:
      raise NotImplementedError(f"{mj_type(unsupported).name} is unsupported.")

  if ((mjm.flex_contype != 0) | (mjm.flex_conaffinity != 0)).any():
    raise NotImplementedError("Flex collisions are not implemented.")

  if mjm.opt.noslip_iterations > 0:
    raise NotImplementedError("noslip solver not implemented.")

  if (mjm.opt.viscosity > 0 or mjm.opt.density > 0) and mjm.opt.integrator in (
    mujoco.mjtIntegrator.mjINT_IMPLICITFAST,
    mujoco.mjtIntegrator.mjINT_IMPLICIT,
  ):
    raise NotImplementedError("Implicit integrators and fluid model not implemented.")

  if (mjm.body_plugin != -1).any():
    raise NotImplementedError("Body plugins not supported.")

  if (mjm.actuator_plugin != -1).any():
    raise NotImplementedError("Actuator plugins not supported.")

  if (mjm.sensor_plugin != -1).any():
    raise NotImplementedError("Sensor plugins not supported.")

  # TODO(team): remove after _update_gradient for Newton uses tile operations for islands
  nv_max = 60
  if mjm.nv > nv_max and mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_DENSE:
    raise ValueError(f"Dense is unsupported for nv > {nv_max} (nv = {mjm.nv}).")

  collision_sensors = (mujoco.mjtSensor.mjSENS_GEOMDIST, mujoco.mjtSensor.mjSENS_GEOMNORMAL, mujoco.mjtSensor.mjSENS_GEOMFROMTO)
  is_collision_sensor = np.isin(mjm.sensor_type, collision_sensors)

  def _check_friction(name: str, id_: int, condim: int, friction, checks):
    # checks: (min_condim, friction indices) -- warn for each coefficient that is active at
    # this condim but smaller than MJ_MINMU
    for min_condim, indices in checks:
      if condim >= min_condim:
        for idx in indices:
          if friction[idx] < types.MJ_MINMU:
            warnings.warn(
              f"{name} {id_}: friction[{idx}] ({friction[idx]}) < MJ_MINMU ({types.MJ_MINMU}) with condim={condim} may cause NaN"
            )

  for geomid in range(mjm.ngeom):
    _check_friction("geom", geomid, mjm.geom_condim[geomid], mjm.geom_friction[geomid], [(3, [0]), (4, [1]), (6, [2])])
  for pairid in range(mjm.npair):
    _check_friction("pair", pairid, mjm.pair_dim[pairid], mjm.pair_friction[pairid], [(3, [0]), (4, [1, 2]), (6, [3, 4])])

  # create opt
  opt_kwargs = {f.name: getattr(mjm.opt, f.name, None) for f in dataclasses.fields(types.Option)}
  if hasattr(mjm.opt, "impratio"):
    opt_kwargs["impratio_invsqrt"] = 1.0 / np.sqrt(np.maximum(mjm.opt.impratio, mujoco.mjMINVAL))
  opt = types.Option(**opt_kwargs)

  # C MuJoCo tolerance was chosen for float64 architecture, but we default to float32 on GPU
  # adjust the tolerance for lower precision, to avoid the solver spending iterations needlessly
  # bouncing around the optimal solution
  opt.tolerance = max(opt.tolerance, 1e-6)

  # warp only fields
  ls_parallel_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_NUMERIC, "ls_parallel")
  opt.ls_parallel = (ls_parallel_id > -1) and (mjm.numeric_data[mjm.numeric_adr[ls_parallel_id]] == 1)
  opt.ls_parallel_min_step = 1.0e-6  # TODO(team): determine good default setting
  # placeholder so the device conversion below has a valid value; chosen for real further down
  opt.broadphase = types.BroadphaseType.NXN
  opt.broadphase_filter = types.BroadphaseFilter.PLANE | types.BroadphaseFilter.SPHERE | types.BroadphaseFilter.OBB
  opt.graph_conditional = True
  opt.run_collision_detection = True
  contact_sensor_maxmatch_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_NUMERIC, "contact_sensor_maxmatch")
  if contact_sensor_maxmatch_id > -1:
    opt.contact_sensor_maxmatch = mjm.numeric_data[mjm.numeric_adr[contact_sensor_maxmatch_id]]
  else:
    opt.contact_sensor_maxmatch = 64

  # place opt on device
  for f in dataclasses.fields(types.Option):
    if isinstance(f.type, wp.array):
      setattr(opt, f.name, _create_array(getattr(opt, f.name), f.type, {"*": 1}))
    else:
      setattr(opt, f.name, f.type(getattr(opt, f.name)))

  # create stat
  stat = types.Statistic(meaninertia=_create_array([mjm.stat.meaninertia], types.array("*", float), {"*": 1}))

  # create model
  m = types.Model(**{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model)})
  sparse = is_sparse(mjm)

  m.opt = opt
  m.stat = stat
  m.nv_pad = _get_padded_sizes(mjm.nv, 0, sparse, types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE)[1]
  m.nacttrnbody = (mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY).sum()
  m.nsensortaxel = mjm.mesh_vertnum[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_TACTILE]].sum()
  m.nsensorcontact = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT).sum()
  m.nrangefinder = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER).sum()
  m.nmaxcondim = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  m.nmaxpyramid = np.maximum(1, 2 * (m.nmaxcondim - 1))
  m.has_sdf_geom = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  m.block_dim = types.BlockDim()
  m.is_sparse = sparse
  m.has_fluid = mjm.opt.wind.any() or mjm.opt.density > 0 or mjm.opt.viscosity > 0

  # body ids grouped by tree level (depth-based traversal)
  bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1
  for i in range(mjm.nbody):
    body_depth[i] = body_depth[mjm.body_parentid[i]] + 1
    bodies.setdefault(body_depth[i], []).append(i)
  m.body_tree = tuple(wp.array(bodies[i], dtype=int) for i in sorted(bodies))

  # branch-based traversal data: one branch per leaf body, listing its ancestors (world excluded)
  # from the root down to the leaf
  children_count = np.bincount(mjm.body_parentid[1:], minlength=mjm.nbody)

  def _ancestor_chain(body):
    # iterative, so very deep kinematic chains cannot hit Python's recursion limit
    chain = []
    while body:
      chain.append(body)
      body = mjm.body_parentid[body]
    return chain[::-1]

  branches = [_ancestor_chain(leaf) for leaf in np.where(children_count[1:] == 0)[0] + 1]
  m.nbranch = len(branches)
  body_branches = []
  body_branch_start = []
  offset = 0
  for branch in branches:
    body_branches.extend(branch)
    body_branch_start.append(offset)
    offset += len(branch)
  body_branch_start.append(offset)
  m.body_branches = np.array(body_branches, dtype=int)
  m.body_branch_start = np.array(body_branch_start, dtype=int)

  # mocap bodies ordered by mocap id
  m.mocap_bodyid = np.arange(mjm.nbody)[mjm.body_mocapid >= 0]
  m.mocap_bodyid = m.mocap_bodyid[mjm.body_mocapid[mjm.body_mocapid >= 0].argsort()]
  m.body_fluid_ellipsoid = np.zeros(mjm.nbody, dtype=bool)
  m.body_fluid_ellipsoid[mjm.geom_bodyid[mjm.geom_fluid.reshape(mjm.ngeom, mujoco.mjNFLUID)[:, 0] > 0]] = True
  jnt_limited_slide_hinge = mjm.jnt_limited & np.isin(mjm.jnt_type, (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE))
  m.jnt_limited_slide_hinge_adr = np.nonzero(jnt_limited_slide_hinge)[0]
  m.jnt_limited_ball_adr = np.nonzero(mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL))[0]
  m.dof_tri_row, m.dof_tri_col = np.tril_indices(mjm.nv)

  # precalculated geom pairs
  filterparent = not (mjm.opt.disableflags & types.DisableBit.FILTERPARENT)

  geom1, geom2 = np.triu_indices(mjm.ngeom, k=1)
  m.nxn_geom_pair = np.stack((geom1, geom2), axis=1)

  bodyid1 = mjm.geom_bodyid[geom1]
  bodyid2 = mjm.geom_bodyid[geom2]
  contype1 = mjm.geom_contype[geom1]
  contype2 = mjm.geom_contype[geom2]
  conaffinity1 = mjm.geom_conaffinity[geom1]
  conaffinity2 = mjm.geom_conaffinity[geom2]
  weldid1 = mjm.body_weldid[bodyid1]
  weldid2 = mjm.body_weldid[bodyid2]
  weld_parentid1 = mjm.body_weldid[mjm.body_parentid[weldid1]]
  weld_parentid2 = mjm.body_weldid[mjm.body_parentid[weldid2]]

  self_collision = weldid1 == weldid2
  parent_child_collision = (
    filterparent & (weldid1 != 0) & (weldid2 != 0) & ((weldid1 == weld_parentid2) | (weldid2 == weld_parentid1))
  )
  mask = np.array((contype1 & conaffinity2) | (contype2 & conaffinity1), dtype=bool)
  exclude = np.isin((bodyid1 << 16) + bodyid2, mjm.exclude_signature)

  # -1: candidate dynamic contact, -2: filtered out, >= 0: explicit contact pair id
  nxn_pairid_contact = np.full(len(geom1), -1, dtype=int)
  nxn_pairid_contact[~(mask & ~self_collision & ~parent_child_collision & ~exclude)] = -2

  def upper_tri_index(n, i, j):
    # position of unordered pair (i, j), i != j, in np.triu_indices(n, k=1) order
    i, j = (j, i) if j < i else (i, j)
    return (i * (2 * n - i - 3)) // 2 + j - 1

  # contact pairs
  for i in range(mjm.npair):
    nxn_pairid_contact[upper_tri_index(mjm.ngeom, mjm.pair_geom1[i], mjm.pair_geom2[i])] = i

  # collision sensors: each distinct geom pair gets a dense collision id; every (sensor, geom pair)
  # combination records the collision id it reads from
  sensor_collision_adr = np.nonzero(is_collision_sensor)[0]
  nxn_pairid_collision = np.full(len(geom1), -1, dtype=int)
  ncollisionpair = 0
  sensor_collision_start_adr = []
  for sensorid in sensor_collision_adr:
    objtype = mjm.sensor_objtype[sensorid]
    objid = mjm.sensor_objid[sensorid]
    reftype = mjm.sensor_reftype[sensorid]
    refid = mjm.sensor_refid[sensorid]

    # get lists of geoms to collide
    if objtype == types.ObjType.BODY:
      n1 = mjm.body_geomnum[objid]
      id1 = mjm.body_geomadr[objid]
    else:
      n1 = 1
      id1 = objid
    if reftype == types.ObjType.BODY:
      n2 = mjm.body_geomnum[refid]
      id2 = mjm.body_geomadr[refid]
    else:
      n2 = 1
      id2 = refid

    # collide all pairs; a pair already claimed by an earlier sensor reuses its collision id
    for geom1id in range(id1, id1 + n1):
      for geom2id in range(id2, id2 + n2):
        pairid = upper_tri_index(mjm.ngeom, geom1id, geom2id)
        if nxn_pairid_collision[pairid] < 0:
          nxn_pairid_collision[pairid] = ncollisionpair
          ncollisionpair += 1
        sensor_collision_start_adr.append(nxn_pairid_collision[pairid])

  m.nsensorcollision = (nxn_pairid_collision >= 0).sum()
  m.sensor_collision_start_adr = np.array(sensor_collision_start_adr)

  nxn_include = (nxn_pairid_contact > -2) | (nxn_pairid_collision >= 0)

  if nxn_include.sum() < 250_000:
    opt.broadphase = types.BroadphaseType.NXN
  elif mjm.ngeom < 1000:
    opt.broadphase = types.BroadphaseType.SAP_TILE
  else:
    opt.broadphase = types.BroadphaseType.SAP_SEGMENTED

  m.nxn_geom_pair_filtered = m.nxn_geom_pair[nxn_include]
  m.nxn_pairid = np.hstack([nxn_pairid_contact.reshape((-1, 1)), nxn_pairid_collision.reshape((-1, 1))])
  m.nxn_pairid_filtered = m.nxn_pairid[nxn_include]

  # count contact pair types: each unordered (type_a, type_b) maps to its index in the upper
  # triangle (diagonal included) of a len(GeomType) x len(GeomType) matrix
  ngeomtype = len(types.GeomType)
  pair_type1 = mjm.geom_type[geom1[nxn_include]]
  pair_type2 = mjm.geom_type[geom2[nxn_include]]
  type_lo = np.minimum(pair_type1, pair_type2)
  type_hi = np.maximum(pair_type1, pair_type2)
  geom_trid = (type_lo * (2 * ngeomtype - type_lo - 1)) // 2 + type_hi
  m.geom_pair_type_count = tuple(np.bincount(geom_trid, minlength=ngeomtype * (ngeomtype + 1) // 2))

  m.nmaxpolygon = np.append(mjm.mesh_polyvertnum, 0).max()
  m.nmaxmeshdeg = np.append(mjm.mesh_polymapnum, 0).max()

  # filter plugins for only geom plugins, drop the rest
  m.plugin, m.plugin_attr = [], []
  m.geom_plugin_index = np.full_like(mjm.geom_type, -1)

  for i in range(len(mjm.geom_plugin)):
    if mjm.geom_plugin[i] == -1:
      continue
    p = mjm.geom_plugin[i]
    m.geom_plugin_index[i] = len(m.plugin)
    m.plugin.append(mjm.plugin[p])
    start = mjm.plugin_attradr[p]
    end = mjm.plugin_attradr[p + 1] if p + 1 < mjm.nplugin else len(mjm.plugin_attr)
    values = mjm.plugin_attr[start:end]
    # attributes are stored as null-separated character codes; parse each as a float
    attr_values = []
    current = []
    for v in values:
      if v == 0:
        if current:
          s = "".join(chr(int(x)) for x in current)
          attr_values.append(float(s))
          current = []
      else:
        current.append(v)
    # Pad with zeros if less than 3
    attr_values += [0.0] * (3 - len(attr_values))
    m.plugin_attr.append(attr_values[:3])

  # equality constraint addresses
  m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0]
  m.eq_wld_adr = np.nonzero(mjm.eq_type == types.EqType.WELD)[0]
  m.eq_jnt_adr = np.nonzero(mjm.eq_type == types.EqType.JOINT)[0]
  m.eq_ten_adr = np.nonzero(mjm.eq_type == types.EqType.TENDON)[0]
  m.eq_flex_adr = np.nonzero(mjm.eq_type == types.EqType.FLEX)[0]

  # fixed tendon
  m.tendon_jnt_adr, m.wrap_jnt_adr = [], []
  for i in range(mjm.ntendon):
    adr = mjm.tendon_adr[i]
    if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT:
      tendon_num = mjm.tendon_num[i]
      for j in range(tendon_num):
        m.tendon_jnt_adr.append(i)
        m.wrap_jnt_adr.append(adr + j)

  # spatial tendon
  m.tendon_site_pair_adr, m.tendon_geom_adr = [], []
  m.ten_wrapadr_site, m.ten_wrapnum_site = [0], []
  for i, tendon_num in enumerate(mjm.tendon_num):
    adr = mjm.tendon_adr[i]
    # sites: a tendon counts toward the site-only layout only if every wrap is a site
    if (mjm.wrap_type[adr : adr + tendon_num] == mujoco.mjtWrap.mjWRAP_SITE).all():
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1] + tendon_num)
      m.ten_wrapnum_site.append(tendon_num)
    else:
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1])
      m.ten_wrapnum_site.append(0)

    # geoms
    for j in range(tendon_num):
      wrap_type = mjm.wrap_type[adr + j]
      if j < tendon_num - 1:
        next_wrap_type = mjm.wrap_type[adr + j + 1]
        if wrap_type == mujoco.mjtWrap.mjWRAP_SITE and next_wrap_type == mujoco.mjtWrap.mjWRAP_SITE:
          m.tendon_site_pair_adr.append(i)
      if wrap_type == mujoco.mjtWrap.mjWRAP_SPHERE or wrap_type == mujoco.mjtWrap.mjWRAP_CYLINDER:
        m.tendon_geom_adr.append(i)

  m.tendon_limited_adr = np.nonzero(mjm.tendon_limited)[0]
  m.wrap_site_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_SITE)[0]
  # consecutive site wraps, excluding pairs that straddle a tendon boundary
  m.wrap_site_pair_adr = np.setdiff1d(m.wrap_site_adr[np.nonzero(np.diff(m.wrap_site_adr) == 1)[0]], mjm.tendon_adr[1:] - 1)
  m.wrap_geom_adr = np.nonzero(np.isin(mjm.wrap_type, [mujoco.mjtWrap.mjWRAP_SPHERE, mujoco.mjtWrap.mjWRAP_CYLINDER]))[0]

  # pulley scaling: each pulley's divisor applies from the pulley to the end of its tendon;
  # later pulleys in the same tendon overwrite earlier ones
  m.wrap_pulley_scale = np.ones(mjm.nwrap, dtype=float)
  pulley_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_PULLEY)[0]
  for tadr, tnum in zip(mjm.tendon_adr, mjm.tendon_num):
    for padr in pulley_adr:
      if tadr <= padr < tadr + tnum:
        m.wrap_pulley_scale[padr : tadr + tnum] = 1.0 / mjm.wrap_prm[padr]

  m.actuator_trntype_body_adr = np.nonzero(mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY)[0]

  # sensor addresses
  m.sensor_pos_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITPOS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_limitpos_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITPOS) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_vel_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITVEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  m.sensor_limitvel_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITVEL) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  # acc-stage sensors that have dedicated address lists below are excluded here. (The previous
  # OR of `!=` tests was always true, so nothing was actually excluded.)
  acc_dedicated_sensors = (
    mujoco.mjtSensor.mjSENS_TOUCH,
    mujoco.mjtSensor.mjSENS_JOINTLIMITFRC,
    mujoco.mjtSensor.mjSENS_TENDONLIMITFRC,
    mujoco.mjtSensor.mjSENS_TENDONACTFRC,
  )
  m.sensor_acc_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC) & ~np.isin(mjm.sensor_type, acc_dedicated_sensors)
  )[0]
  m.sensor_rangefinder_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER)[0]
  m.rangefinder_sensor_adr = np.full(mjm.nsensor, -1)
  m.rangefinder_sensor_adr[m.sensor_rangefinder_adr] = np.arange(len(m.sensor_rangefinder_adr))
  m.collision_sensor_adr = np.full(mjm.nsensor, -1)
  m.collision_sensor_adr[sensor_collision_adr] = np.arange(len(sensor_collision_adr))
  m.sensor_touch_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TOUCH)[0]
  limitfrc_sensors = (mujoco.mjtSensor.mjSENS_JOINTLIMITFRC, mujoco.mjtSensor.mjSENS_TENDONLIMITFRC)
  m.sensor_limitfrc_adr = np.nonzero(np.isin(mjm.sensor_type, limitfrc_sensors))[0]
  m.sensor_e_potential = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_POTENTIAL).any()
  m.sensor_e_kinetic = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_KINETIC).any()
  m.sensor_tendonactfrc_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONACTFRC)[0]
  subtreevel_sensors = (mujoco.mjtSensor.mjSENS_SUBTREELINVEL, mujoco.mjtSensor.mjSENS_SUBTREEANGMOM)
  m.sensor_subtree_vel = np.isin(mjm.sensor_type, subtreevel_sensors).any()
  m.sensor_contact_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT)[0]
  m.sensor_adr_to_contact_adr = np.clip(np.cumsum(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT) - 1, a_min=0, a_max=None)
  m.sensor_rne_postconstraint = np.isin(
    mjm.sensor_type,
    [
      mujoco.mjtSensor.mjSENS_ACCELEROMETER,
      mujoco.mjtSensor.mjSENS_FORCE,
      mujoco.mjtSensor.mjSENS_TORQUE,
      mujoco.mjtSensor.mjSENS_FRAMELINACC,
      mujoco.mjtSensor.mjSENS_FRAMEANGACC,
    ],
  ).any()
  m.sensor_rangefinder_bodyid = mjm.site_bodyid[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER]]
  m.taxel_vertadr = [
    j + mjm.mesh_vertadr[mjm.sensor_objid[i]]
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]
  m.taxel_sensorid = [
    i
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]

  # qM_tiles records the block diagonal structure of qM
  tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1]
  tiles = {}
  for i in range(len(tile_corners)):
    tile_beg = tile_corners[i]
    tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1]
    tiles.setdefault(tile_end - tile_beg, []).append(tile_beg)
  m.qM_tiles = tuple(types.TileSet(adr=wp.array(tiles[sz], dtype=int), size=sz) for sz in sorted(tiles.keys()))

  # qLD_updates has dof tree ordering of qLD updates for sparse factor m
  qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1
  for k in range(mjm.nv):
    # skip diagonal rows
    if mjm.M_rownnz[k] == 1:
      continue
    dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1
    i = mjm.dof_parentid[k]
    diag_k = mjm.M_rowadr[k] + mjm.M_rownnz[k] - 1
    Madr_ki = diag_k - 1
    while i > -1:
      qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki))
      i = mjm.dof_parentid[i]
      Madr_ki -= 1

  m.qLD_updates = tuple(wp.array(qLD_updates[i], dtype=wp.vec3i) for i in sorted(qLD_updates))

  # indices for sparse qM_fullm (used in solver)
  m.qM_fullm_i, m.qM_fullm_j = [], []
  for i in range(mjm.nv):
    j = i
    while j > -1:
      m.qM_fullm_i.append(i)
      m.qM_fullm_j.append(j)
      j = mjm.dof_parentid[j]

  # Gather-based sparse mul_m: for each row, all (col, madr) including diagonal
  row_elements = [[] for _ in range(mjm.nv)]

  # Add diagonal
  for i in range(mjm.nv):
    row_elements[i].append((i, mjm.dof_Madr[i]))

  # Add off-diagonals: ancestors (lower) and descendants (upper)
  for i in range(mjm.nv):
    madr_ij, j = mjm.dof_Madr[i], i
    while True:
      madr_ij, j = madr_ij + 1, mjm.dof_parentid[j]
      if j == -1:
        break
      row_elements[i].append((j, madr_ij))  # row i gathers M[i,j] * vec[j]
      row_elements[j].append((i, madr_ij))  # row j gathers M[j,i] * vec[i]

  # Flatten into CSR-like arrays
  m.qM_mulm_rowadr = [0]
  m.qM_mulm_col = []
  m.qM_mulm_madr = []
  for i in range(mjm.nv):
    for col, madr in row_elements[i]:
      m.qM_mulm_col.append(col)
      m.qM_mulm_madr.append(madr)
    m.qM_mulm_rowadr.append(len(m.qM_mulm_col))

  # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
  if BLEEDING_EDGE_MUJOCO:
    m.flexedge_J_rownnz = mjm.flexedge_J_rownnz
    m.flexedge_J_rowadr = mjm.flexedge_J_rowadr
    m.flexedge_J_colind = mjm.flexedge_J_colind.reshape(-1)
  else:
    # older releases only expose the flexedge jacobian sparsity on MjData
    mjd = mujoco.MjData(mjm)
    mujoco.mj_forward(mjm, mjd)
    m.flexedge_J_rownnz = mjd.flexedge_J_rownnz
    m.flexedge_J_rowadr = mjd.flexedge_J_rowadr
    m.flexedge_J_colind = mjd.flexedge_J_colind.reshape(-1)

  # place m on device
  sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
  for f in dataclasses.fields(types.Model):
    if isinstance(f.type, wp.array):
      setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes))

  return m
def _get_padded_sizes(nv: int, njmax: int, is_sparse: bool, tile_size: int):
# if dense - we just pad to the next multiple of 4 for nv, to get the fast load path.
# we pad to the next multiple of tile_size for njmax to avoid out of bounds accesses.
# if sparse - we pad to the next multiple of tile_size for njmax, and nv.
def round_up(x, multiple):
return ((x + multiple - 1) // multiple) * multiple
njmax_padded = round_up(njmax, tile_size)
nv_padded = round_up(nv, tile_size) if (is_sparse or nv > 32) else round_up(nv, 4)
return njmax_padded, nv_padded
def _default_nconmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> int:
  """Returns a default guess for an ideal nconmax given a Model and optional Data.

  This guess is based off a very simple heuristic, and may need to be manually raised if MJWarp
  reports ncon overflow, or lowered in order to get the very best performance.

  The estimate is rounded up to the first size in the sequence 16, 24, 32, 48, 64, 96, ...
  The sequence is not capped. A fixed 8192-entry-limited table indexed with ``np.searchsorted``
  raised IndexError for estimates above 8192, e.g. when ``mjd.ncon`` > 8192.

  Args:
    mjm: Host model.
    mjd: Optional host data; if given, its current ``ncon`` is a lower bound.

  Returns:
    The suggested per-world contact capacity.
  """
  has_sdf = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  has_flex = mjm.nflex > 0
  nconmax = max(
    mjm.nv * 0.35 * (mjm.nhfield > 0) * 10 + 45,
    256 * has_flex,
    64 * has_sdf,
    mjd.ncon if mjd is not None else 0,
  )

  # walk 2*2^(k+3), 3*2^(k+3), ... (16, 24, 32, 48, ...) until the estimate fits
  size_idx, size = 0, 16
  while size < nconmax:
    size_idx += 1
    size = (2 + size_idx % 2) * 2 ** (size_idx // 2 + 3)
  return int(size)
def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> int:
  """Returns a default guess for an ideal njmax given a Model and optional Data.

  This guess is based off a very simple heuristic, and may need to be manually raised if MJWarp
  reports that the number of constraints overflowed njmax, or lowered in order to get the very
  best performance.

  The estimate is rounded up to the first size in the sequence 16, 24, 32, 48, 64, 96, ...
  The sequence is not capped. A fixed 8192-entry-limited table indexed with ``np.searchsorted``
  raised IndexError for estimates above 8192, e.g. when ``mjd.nefc`` > 8192.

  Args:
    mjm: Host model.
    mjd: Optional host data; if given, its current ``nefc`` is a lower bound.

  Returns:
    The suggested per-world constraint capacity.
  """
  has_sdf = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  has_flex = mjm.nflex > 0
  njmax = max(
    mjm.nv * 2.26 * (mjm.nhfield > 0) * 18 + 53,
    512 * has_flex,
    256 * has_sdf,
    mjd.nefc if mjd is not None else 0,
  )

  # walk 2*2^(k+3), 3*2^(k+3), ... (16, 24, 32, 48, ...) until the estimate fits
  size_idx, size = 0, 16
  while size < njmax:
    size_idx += 1
    size = (2 + size_idx % 2) * 2 ** (size_idx // 2 + 3)
  return int(size)
[docs]
def make_data(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  nccdmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
  naccdmax: Optional[int] = None,
) -> types.Data:
  """Creates a data object on device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nworld: Number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
    naccdmax: Maximum number of CCD contacts for all worlds. Defaults to
      nworld * nccdmax, capped at naconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: If a size argument is negative, nworld < 1, nccdmax > nconmax, or an explicit
      naccdmax > naconmax.
  """
  # TODO(team): move nconmax, njmax to Model?
  if nconmax is None:
    nconmax = _default_nconmax(mjm)
  if nconmax < 0:
    raise ValueError("nconmax must be >= 0")

  if nccdmax is None:
    nccdmax = nconmax
  elif nccdmax < 0:
    raise ValueError("nccdmax must be >= 0")
  elif nccdmax > nconmax:
    raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")

  if njmax is None:
    njmax = _default_njmax(mjm)
  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  if naconmax is None:
    naconmax = nworld * nconmax
  elif naconmax < 0:
    raise ValueError("naconmax must be >= 0")

  if naccdmax is None:
    # an explicit naconmax may be smaller than nworld * nccdmax; cap the default so it never
    # violates the naccdmax <= naconmax invariant enforced below for explicit values
    naccdmax = min(nworld * nccdmax, naconmax)
  elif naccdmax < 0:
    raise ValueError("naccdmax must be >= 0")
  elif naccdmax > naconmax:
    raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")

  sparse = is_sparse(mjm)
  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)})
  efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)})

  # world body and static geom (attached to the world) poses are precomputed
  # this speeds up scenes with many static geoms (e.g. terrains)
  # TODO(team): remove this when we introduce dof islands + sleeping
  mjd = mujoco.MjData(mjm)
  mujoco.mj_kinematics(mjm, mjd)

  # mocap: row j must hold the body whose body_mocapid == j (same argsort ordering as put_model)
  mocap_body = np.nonzero(mjm.body_mocapid >= 0)[0]
  mocap_body = mocap_body[np.argsort(mjm.body_mocapid[mocap_body])]

  d_kwargs = {
    "qpos": wp.array(np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float),
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "naccdmax": naccdmax,
    "njmax": njmax,
    "qM": None,
    "qLD": None,
    # world body
    "xquat": wp.array(np.tile(mjd.xquat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.quat),
    "xmat": wp.array(np.tile(mjd.xmat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    "ximat": wp.array(np.tile(mjd.ximat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    # static geoms
    "geom_xpos": wp.array(np.tile(mjd.geom_xpos, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.vec3),
    "geom_xmat": wp.array(np.tile(mjd.geom_xmat, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.mat33),
    # mocap
    "mocap_pos": wp.array(np.tile(mjm.body_pos[mocap_body], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.vec3),
    "mocap_quat": wp.array(np.tile(mjm.body_quat[mocap_body], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.quat),
    # equality constraints
    "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
    # flexedge
    "flexedge_J": None,
  }

  # every remaining field is allocated (zero-filled) from its type spec
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    d_kwargs[f.name] = _create_array(None, f.type, sizes)

  d = types.Data(**d_kwargs)

  if sparse:
    d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=float)
    d.qLD = wp.zeros((nworld, 1, mjm.nC), dtype=float)
  else:
    d.qM = wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float)
    d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=float)

  d.flexedge_J = wp.zeros((nworld, 1, mjd.flexedge_J.size), dtype=float)

  return d
[docs]
def put_data(
  mjm: mujoco.MjModel,
  mjd: mujoco.MjData,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  nccdmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
  naccdmax: Optional[int] = None,
) -> types.Data:
  """Moves data from host to a device.

  Every world is initialized as a copy of ``mjd``. Contacts are stored
  world-major: contact ``i`` of world ``w`` lives at index ``w * mjd.ncon + i``
  (the layout used by ``contact.efc_address`` and ``contact.worldid``).

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: The data object containing current state and output arrays (host).
      ``mujoco.mj_kinematics`` is run on it in place.
    nworld: The number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogenous arrays: one world may have more than nconmax contacts.
    nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
    naccdmax: Number of CCD contacts to allocate for all worlds. Defaults to
      ``nworld * nccdmax``, capped at naconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: If a size is negative, ``nworld < 1``, ``nccdmax > nconmax``,
      ``naccdmax > naconmax``, or the allocation cannot hold the contacts or
      constraints already present in ``mjd``.
  """
  # TODO(team): move nconmax and njmax to Model?
  # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
  # we need to ensure these are only workspace fields and don't carry state

  if nconmax is None:
    nconmax = _default_nconmax(mjm, mjd)
  if nconmax < 0:
    raise ValueError("nconmax must be >= 0")

  if nccdmax is None:
    nccdmax = nconmax
  elif nccdmax < 0:
    raise ValueError("nccdmax must be >= 0")
  elif nccdmax > nconmax:
    raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")

  if njmax is None:
    njmax = _default_njmax(mjm, mjd)
  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  if naconmax is None:
    if mjd.ncon > nconmax:
      raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
    naconmax = nworld * nconmax
  elif naconmax < mjd.ncon * nworld:
    raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")

  if naccdmax is None:
    # an explicit naconmax may be smaller than nworld * nccdmax; cap the default so it
    # satisfies the same naccdmax <= naconmax invariant enforced for explicit values
    naccdmax = min(nworld * nccdmax, naconmax)
  elif naccdmax < 0:
    raise ValueError("naccdmax must be >= 0")
  elif naccdmax > naconmax:
    raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")

  if mjd.nefc > njmax:
    raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")

  sparse = is_sparse(mjm)

  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  # ensure static geom positions are computed
  # TODO: remove once MjData creation semantics are fixed
  mujoco.mj_kinematics(mjm, mjd)

  # create contact
  contact_kwargs = {"efc_address": None, "worldid": None, "type": None, "geomcollisionid": None}
  for f in dataclasses.fields(types.Contact):
    if f.name in contact_kwargs:
      continue
    val = getattr(mjd.contact, f.name)
    # replicate world-major ([c0, c1, ..., c0, c1, ...]) to match efc_address and worldid below;
    # np.repeat would interleave copies of each contact instead
    val = np.tile(val, (nworld,) + (1,) * (val.ndim - 1))
    width = ((0, naconmax - val.shape[0]),) + ((0, 0),) * (val.ndim - 1)
    val = np.pad(val, width)
    contact_kwargs[f.name] = _create_array(val, f.type, sizes)

  contact = types.Contact(**contact_kwargs)

  # constraint rows are per-world and identical across worlds, so contact i of every world
  # uses the same efc addresses
  efc_address = np.zeros((naconmax, sizes["nmaxpyramid"]), dtype=int)
  pyramidal = mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL
  for i in range(mjd.ncon):
    adr = mjd.contact.efc_address[i]
    if adr == -1:  # contact has no constraint rows
      continue
    condim = mjd.contact.dim[i]
    ndim = max(1, 2 * (condim - 1)) if pyramidal else condim
    for j in range(nworld):
      efc_address[j * mjd.ncon + i, :ndim] = adr + np.arange(ndim)
  contact.efc_address = wp.array(efc_address, dtype=int)

  worldid = np.repeat(np.arange(nworld), mjd.ncon)
  contact.worldid = wp.array(np.pad(worldid, (0, naconmax - worldid.size)), dtype=int)
  contact.type = wp.ones((naconmax,), dtype=int)  # TODO(team): set values
  contact.geomcollisionid = wp.empty((naconmax,), dtype=int)  # TODO(team): set values

  # create efc
  efc_kwargs = {"J": None}
  for f in dataclasses.fields(types.Constraint):
    if f.name in efc_kwargs:
      continue
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
    val = np.zeros(shape, dtype=f.type.dtype)
    if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
      val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
    efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)

  efc = types.Constraint(**efc_kwargs)

  # efc_J layout on the host follows MuJoCo's own sparsity decision
  if mujoco.mj_isSparse(mjm):
    efc_j = np.zeros((mjd.nefc, mjm.nv))
    mujoco.mju_sparse2dense(efc_j, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind)
  else:
    efc_j = mjd.efc_J.reshape((mjd.nefc, mjm.nv))
  # J is a float field: do not reuse the dtype of whichever field the loop above ended on
  efc_J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float)
  efc_J[:, : mjd.nefc, : mjm.nv] = efc_j  # broadcast to every world
  efc.J = wp.array(efc_J, dtype=float)

  # create data
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "naccdmax": naccdmax,
    "njmax": njmax,
    # fields set after initialization:
    "solver_niter": None,
    "qM": None,
    "qLD": None,
    "ten_J": None,
    "actuator_moment": None,
    "flexedge_J": None,
    "nacon": None,
  }
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    val = getattr(mjd, f.name, None)
    if val is not None:
      shape = val.shape if hasattr(val, "shape") else ()
      val = np.full((nworld,) + shape, val)
    d_kwargs[f.name] = _create_array(val, f.type, sizes)

  d = types.Data(**d_kwargs)

  d.solver_niter = wp.full((nworld,), mjd.solver_niter[0], dtype=int)

  if sparse:
    d.qM = wp.array(np.full((nworld, 1, mjm.nM), mjd.qM), dtype=float)
    d.qLD = wp.array(np.full((nworld, 1, mjm.nC), mjd.qLD), dtype=float)
  else:
    qM = np.zeros((mjm.nv, mjm.nv))
    mujoco.mj_fullM(mjm, qM, mjd.qM)
    qLD = np.linalg.cholesky(qM) if (mjd.qM != 0.0).any() and (mjd.qLD != 0.0).any() else np.zeros((mjm.nv, mjm.nv))
    padding = sizes["nv_pad"] - mjm.nv
    qM_padded = np.pad(qM, ((0, padding), (0, padding)), mode="constant", constant_values=0.0)
    d.qM = wp.array(np.full((nworld, sizes["nv_pad"], sizes["nv_pad"]), qM_padded), dtype=float)
    d.qLD = wp.array(np.full((nworld, mjm.nv, mjm.nv), qLD), dtype=float)

  d.flexedge_J = wp.array(np.tile(mjd.flexedge_J.reshape(-1), (nworld, 1)).reshape((nworld, 1, -1)), dtype=float)

  if mjm.ntendon:
    # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
    if BLEEDING_EDGE_MUJOCO:
      # host ten_J is sparse, with its sparsity structure stored on the model
      ten_J = np.zeros((mjm.ntendon, mjm.nv))
      mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjm.ten_J_rownnz, mjm.ten_J_rowadr, mjm.ten_J_colind.reshape(-1))
    else:
      # older mujoco: host ten_J is already dense (mirrors get_data_into)
      ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
  else:
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), 0.0), dtype=float)

  # TODO(taylorhowell): sparse actuator_moment
  actuator_moment = np.zeros((mjm.nu, mjm.nv))
  mujoco.mju_sparse2dense(actuator_moment, mjd.actuator_moment, mjd.moment_rownnz, mjd.moment_rowadr, mjd.moment_colind)
  d.actuator_moment = wp.array(np.full((nworld, mjm.nu, mjm.nv), actuator_moment), dtype=float)

  d.nacon = wp.array([mjd.ncon * nworld], dtype=int)

  return d
[docs]
def get_data_into(
  result: mujoco.MjData,
  mjm: mujoco.MjModel,
  d: types.Data,
  world_id: int = 0,
):
  """Gets data from a device into an existing mujoco.MjData.

  Only the contacts and constraints of ``world_id`` are copied, and ``result``
  is reallocated when its contact/constraint counts differ. Constraint rows are
  reordered so that each contact's rows are contiguous and follow the
  equality, friction and limit rows. Every ``efc_*`` field, including
  ``efc_J``, is gathered with that same row order.

  Args:
    result: The data object containing the current state and output arrays (host).
    mjm: The model containing kinematic and dynamic information (host).
    d: The data object containing the current state and output arrays (device).
    world_id: The id of the world to get the data from.
  """
  # nacon and nefc can overflow. in that case, only pull up to the max contacts and constraints
  nacon = min(d.nacon.numpy()[0], d.naconmax)
  nefc = min(d.nefc.numpy()[world_id], d.njmax)

  # select the in-range contacts that belong to world_id
  contact_worldid = d.contact.worldid.numpy()
  ncon_filter = np.zeros_like(contact_worldid, dtype=bool)
  ncon_filter[:nacon] = contact_worldid[:nacon] == world_id
  ncon = ncon_filter.sum()

  if ncon != result.ncon or nefc != result.nefc:
    # TODO(team): if sparse, set nJ based on sparse efc_J
    mujoco._functions._realloc_con_efc(result, ncon=ncon, nefc=nefc, nJ=nefc * mjm.nv)

  ne = d.ne.numpy()[world_id]
  nf = d.nf.numpy()[world_id]
  nl = d.nl.numpy()[world_id]

  # efc indexing
  # mujoco expects contiguous efc ordering for contacts
  # this ordering is not guaranteed with mujoco warp, we enforce order here
  if ncon > 0:
    pyramidal = mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL
    contact_dim = d.contact.dim.numpy()[ncon_filter]
    contact_efc_address = d.contact.efc_address.numpy()[ncon_filter]
    efc_idx_c = []
    contact_efc_address_ordered = [ne + nf + nl]
    for i in range(ncon):
      dim = contact_dim[i]
      ndim = np.maximum(1, 2 * (dim - 1)) if pyramidal else dim
      efc_idx_c.append(contact_efc_address[i, :ndim])
      if i < ncon - 1:
        contact_efc_address_ordered.append(contact_efc_address_ordered[-1] + ndim)
    efc_idx = np.concatenate((np.arange(ne + nf + nl), *efc_idx_c))
    contact_efc_address_ordered = np.array(contact_efc_address_ordered)
  else:
    efc_idx = np.arange(nefc)
    contact_efc_address_ordered = np.empty(0)
  efc_idx = efc_idx[:nefc]  # dont emit indices for overflow constraints

  result.solver_niter[0] = d.solver_niter.numpy()[world_id]
  result.ncon = ncon
  result.ne = ne
  result.nf = nf
  result.nl = nl
  result.time = d.time.numpy()[world_id]
  result.energy[:] = d.energy.numpy()[world_id]
  result.qpos[:] = d.qpos.numpy()[world_id]
  result.qvel[:] = d.qvel.numpy()[world_id]
  result.act[:] = d.act.numpy()[world_id]
  result.qacc_warmstart[:] = d.qacc_warmstart.numpy()[world_id]
  result.ctrl[:] = d.ctrl.numpy()[world_id]
  result.qfrc_applied[:] = d.qfrc_applied.numpy()[world_id]
  result.xfrc_applied[:] = d.xfrc_applied.numpy()[world_id]
  result.eq_active[:] = d.eq_active.numpy()[world_id]
  result.mocap_pos[:] = d.mocap_pos.numpy()[world_id]
  result.mocap_quat[:] = d.mocap_quat.numpy()[world_id]
  result.qacc[:] = d.qacc.numpy()[world_id]
  result.act_dot[:] = d.act_dot.numpy()[world_id]
  result.xpos[:] = d.xpos.numpy()[world_id]
  result.xquat[:] = d.xquat.numpy()[world_id]
  result.xmat[:] = d.xmat.numpy()[world_id].reshape((-1, 9))
  result.xipos[:] = d.xipos.numpy()[world_id]
  result.ximat[:] = d.ximat.numpy()[world_id].reshape((-1, 9))
  result.xanchor[:] = d.xanchor.numpy()[world_id]
  result.xaxis[:] = d.xaxis.numpy()[world_id]
  result.geom_xpos[:] = d.geom_xpos.numpy()[world_id]
  result.geom_xmat[:] = d.geom_xmat.numpy()[world_id].reshape((-1, 9))
  result.site_xpos[:] = d.site_xpos.numpy()[world_id]
  result.site_xmat[:] = d.site_xmat.numpy()[world_id].reshape((-1, 9))
  result.cam_xpos[:] = d.cam_xpos.numpy()[world_id]
  result.cam_xmat[:] = d.cam_xmat.numpy()[world_id].reshape((-1, 9))
  result.light_xpos[:] = d.light_xpos.numpy()[world_id]
  result.light_xdir[:] = d.light_xdir.numpy()[world_id]
  result.subtree_com[:] = d.subtree_com.numpy()[world_id]
  result.cdof[:] = d.cdof.numpy()[world_id]
  result.cinert[:] = d.cinert.numpy()[world_id]
  result.flexvert_xpos[:] = d.flexvert_xpos.numpy()[world_id]
  if mjm.nflexedge > 0:
    # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
    if not BLEEDING_EDGE_MUJOCO:
      m = put_model(mjm)
      result.flexedge_J_rownnz[:] = m.flexedge_J_rownnz.numpy()
      result.flexedge_J_rowadr[:] = m.flexedge_J_rowadr.numpy()
      result.flexedge_J_colind[:, :] = m.flexedge_J_colind.numpy().reshape((mjm.nflexedge, mjm.nv))
      mujoco.mju_sparse2dense(
        result.flexedge_J,
        d.flexedge_J.numpy()[world_id].reshape(-1),
        m.flexedge_J_rownnz.numpy(),
        m.flexedge_J_rowadr.numpy(),
        m.flexedge_J_colind.numpy(),
      )
    else:
      result.flexedge_J[:] = d.flexedge_J.numpy()[world_id].reshape(-1)
  result.flexedge_length[:] = d.flexedge_length.numpy()[world_id]
  result.flexedge_velocity[:] = d.flexedge_velocity.numpy()[world_id]
  result.actuator_length[:] = d.actuator_length.numpy()[world_id]
  actuator_moment = d.actuator_moment.numpy()[world_id]
  mujoco.mju_dense2sparse(
    result.actuator_moment, actuator_moment, result.moment_rownnz, result.moment_rowadr, result.moment_colind
  )
  result.crb[:] = d.crb.numpy()[world_id]
  result.qLDiagInv[:] = d.qLDiagInv.numpy()[world_id]
  result.ten_velocity[:] = d.ten_velocity.numpy()[world_id]
  result.actuator_velocity[:] = d.actuator_velocity.numpy()[world_id]
  result.cvel[:] = d.cvel.numpy()[world_id]
  result.cdof_dot[:] = d.cdof_dot.numpy()[world_id]
  result.qfrc_bias[:] = d.qfrc_bias.numpy()[world_id]
  result.qfrc_spring[:] = d.qfrc_spring.numpy()[world_id]
  result.qfrc_damper[:] = d.qfrc_damper.numpy()[world_id]
  result.qfrc_gravcomp[:] = d.qfrc_gravcomp.numpy()[world_id]
  result.qfrc_fluid[:] = d.qfrc_fluid.numpy()[world_id]
  result.qfrc_passive[:] = d.qfrc_passive.numpy()[world_id]
  result.subtree_linvel[:] = d.subtree_linvel.numpy()[world_id]
  result.subtree_angmom[:] = d.subtree_angmom.numpy()[world_id]
  result.actuator_force[:] = d.actuator_force.numpy()[world_id]
  result.qfrc_actuator[:] = d.qfrc_actuator.numpy()[world_id]
  result.qfrc_smooth[:] = d.qfrc_smooth.numpy()[world_id]
  result.qacc_smooth[:] = d.qacc_smooth.numpy()[world_id]
  result.qfrc_constraint[:] = d.qfrc_constraint.numpy()[world_id]
  result.qfrc_inverse[:] = d.qfrc_inverse.numpy()[world_id]

  # contact
  result.contact.dist[:ncon] = d.contact.dist.numpy()[ncon_filter]
  result.contact.pos[:ncon] = d.contact.pos.numpy()[ncon_filter]
  result.contact.frame[:ncon] = d.contact.frame.numpy()[ncon_filter].reshape((-1, 9))
  result.contact.includemargin[:ncon] = d.contact.includemargin.numpy()[ncon_filter]
  result.contact.friction[:ncon] = d.contact.friction.numpy()[ncon_filter]
  result.contact.solref[:ncon] = d.contact.solref.numpy()[ncon_filter]
  result.contact.solreffriction[:ncon] = d.contact.solreffriction.numpy()[ncon_filter]
  result.contact.solimp[:ncon] = d.contact.solimp.numpy()[ncon_filter]
  result.contact.dim[:ncon] = d.contact.dim.numpy()[ncon_filter]
  result.contact.geom[:ncon] = d.contact.geom.numpy()[ncon_filter]
  result.contact.efc_address[:ncon] = contact_efc_address_ordered[:ncon]

  if is_sparse(mjm):
    result.qM[:] = d.qM.numpy()[world_id, 0]
    result.qLD[:] = d.qLD.numpy()[world_id, 0]
  else:
    # pack the dense device matrix into MuJoCo's qM layout by walking each dof's ancestor chain
    qM = d.qM.numpy()[world_id]
    adr = 0
    for i in range(mjm.nv):
      j = i
      while j >= 0:
        result.qM[adr] = qM[i, j]
        j = mjm.dof_parentid[j]
        adr += 1
    mujoco.mj_factorM(mjm, result)

  if nefc > 0:
    # gather rows with efc_idx (sparse AND dense) so efc_J stays aligned with the
    # reordered efc_* fields below
    efc_J = d.efc.J.numpy()[world_id, efc_idx, : mjm.nv]
    if mujoco.mj_isSparse(mjm):
      mujoco.mju_dense2sparse(result.efc_J, efc_J, result.efc_J_rownnz, result.efc_J_rowadr, result.efc_J_colind)
    else:
      result.efc_J[: nefc * mjm.nv] = efc_J.flatten()

  # efc
  result.efc_type[:] = d.efc.type.numpy()[world_id, efc_idx]
  result.efc_id[:] = d.efc.id.numpy()[world_id, efc_idx]
  result.efc_pos[:] = d.efc.pos.numpy()[world_id, efc_idx]
  result.efc_margin[:] = d.efc.margin.numpy()[world_id, efc_idx]
  result.efc_D[:] = d.efc.D.numpy()[world_id, efc_idx]
  result.efc_vel[:] = d.efc.vel.numpy()[world_id, efc_idx]
  result.efc_aref[:] = d.efc.aref.numpy()[world_id, efc_idx]
  result.efc_frictionloss[:] = d.efc.frictionloss.numpy()[world_id, efc_idx]
  result.efc_state[:] = d.efc.state.numpy()[world_id, efc_idx]
  result.efc_force[:] = d.efc.force.numpy()[world_id, efc_idx]

  # rne_postconstraint
  result.cacc[:] = d.cacc.numpy()[world_id]
  result.cfrc_int[:] = d.cfrc_int.numpy()[world_id]
  result.cfrc_ext[:] = d.cfrc_ext.numpy()[world_id]

  # tendon
  result.ten_length[:] = d.ten_length.numpy()[world_id]
  # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
  if BLEEDING_EDGE_MUJOCO:
    ten_J = d.ten_J.numpy()[world_id]
    mujoco.mju_dense2sparse(
      result.ten_J,
      ten_J,
      mjm.ten_J_rownnz,
      mjm.ten_J_rowadr,
      mjm.ten_J_colind,
    )
  else:
    result.ten_J[:] = d.ten_J.numpy()[world_id]
  result.ten_wrapadr[:] = d.ten_wrapadr.numpy()[world_id]
  result.ten_wrapnum[:] = d.ten_wrapnum.numpy()[world_id]
  result.wrap_obj[:] = d.wrap_obj.numpy()[world_id]
  result.wrap_xpos[:] = d.wrap_xpos.numpy()[world_id]

  # sensors
  result.sensordata[:] = d.sensordata.numpy()[world_id]
[docs]
def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None):
  """Clear data, set defaults; optionally by world.

  Resets xfrc_applied, qM, mocap_pos/mocap_quat (to the model's body_pos/body_quat),
  the contacts of reset worlds, and the per-world state written by the
  ``reset_nworld`` kernel below (qpos from qpos0, eq_active from eq_active0,
  everything else zeroed).

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
    reset: Per-world bitmask. Reset if True. If None, every world is reset.
  """

  @wp.kernel(module="unique", enable_backward=False)
  def reset_xfrc_applied(reset_in: wp.array(dtype=bool), xfrc_applied_out: wp.array2d(dtype=wp.spatial_vector)):
    worldid, bodyid, elemid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    xfrc_applied_out[worldid, bodyid][elemid] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_qM(reset_in: wp.array(dtype=bool), qM_out: wp.array3d(dtype=float)):
    worldid, elemid1, elemid2 = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    qM_out[worldid, elemid1, elemid2] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_nworld(
    # Model:
    nq: int,
    nv: int,
    nu: int,
    na: int,
    neq: int,
    nsensordata: int,
    qpos0: wp.array2d(dtype=float),
    eq_active0: wp.array(dtype=bool),
    # Data in:
    nworld_in: int,
    # In:
    reset_in: wp.array(dtype=bool),
    # Data out:
    solver_niter_out: wp.array(dtype=int),
    ne_out: wp.array(dtype=int),
    nf_out: wp.array(dtype=int),
    nl_out: wp.array(dtype=int),
    nefc_out: wp.array(dtype=int),
    time_out: wp.array(dtype=float),
    energy_out: wp.array(dtype=wp.vec2),
    qpos_out: wp.array2d(dtype=float),
    qvel_out: wp.array2d(dtype=float),
    act_out: wp.array2d(dtype=float),
    qacc_warmstart_out: wp.array2d(dtype=float),
    ctrl_out: wp.array2d(dtype=float),
    qfrc_applied_out: wp.array2d(dtype=float),
    eq_active_out: wp.array2d(dtype=bool),
    qacc_out: wp.array2d(dtype=float),
    act_dot_out: wp.array2d(dtype=float),
    sensordata_out: wp.array2d(dtype=float),
    nacon_out: wp.array(dtype=int),
  ):
    worldid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    solver_niter_out[worldid] = 0
    if worldid == 0:
      nacon_out[0] = 0
    ne_out[worldid] = 0
    nf_out[worldid] = 0
    nl_out[worldid] = 0
    nefc_out[worldid] = 0
    time_out[worldid] = 0.0
    energy_out[worldid] = wp.vec2(0.0, 0.0)

    qpos0_id = worldid % qpos0.shape[0]
    for i in range(nq):
      qpos_out[worldid, i] = qpos0[qpos0_id, i]
    for i in range(nv):
      qvel_out[worldid, i] = 0.0
      qacc_warmstart_out[worldid, i] = 0.0
      qfrc_applied_out[worldid, i] = 0.0
      qacc_out[worldid, i] = 0.0
    for i in range(nu):
      ctrl_out[worldid, i] = 0.0
    # na is not bounded by nu (an actuator can own several activation states),
    # so activations get their own loop
    for i in range(na):
      act_out[worldid, i] = 0.0
      act_dot_out[worldid, i] = 0.0
    for i in range(neq):
      eq_active_out[worldid, i] = eq_active0[i]
    for i in range(nsensordata):
      sensordata_out[worldid, i] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_mocap(
    # Model:
    body_mocapid: wp.array(dtype=int),
    body_pos: wp.array2d(dtype=wp.vec3),
    body_quat: wp.array2d(dtype=wp.quat),
    # In:
    reset_in: wp.array(dtype=bool),
    # Data out:
    mocap_pos_out: wp.array2d(dtype=wp.vec3),
    mocap_quat_out: wp.array2d(dtype=wp.quat),
  ):
    worldid, bodyid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    mocapid = body_mocapid[bodyid]

    if mocapid >= 0:
      # body_pos / body_quat may hold fewer rows than there are worlds (batched model
      # fields default to a single row), so wrap worldid like the other kernels do
      body_pos_id = worldid % body_pos.shape[0]
      body_quat_id = worldid % body_quat.shape[0]
      mocap_pos_out[worldid, mocapid] = body_pos[body_pos_id, bodyid]
      mocap_quat_out[worldid, mocapid] = body_quat[body_quat_id, bodyid]

  @wp.kernel(module="unique", enable_backward=False)
  def reset_contact(
    # Data in:
    nacon_in: wp.array(dtype=int),
    # In:
    reset_in: wp.array(dtype=bool),
    nefcaddress: int,
    # Data out:
    contact_dist_out: wp.array(dtype=float),
    contact_pos_out: wp.array(dtype=wp.vec3),
    contact_frame_out: wp.array(dtype=wp.mat33),
    contact_includemargin_out: wp.array(dtype=float),
    contact_friction_out: wp.array(dtype=types.vec5),
    contact_solref_out: wp.array(dtype=wp.vec2),
    contact_solreffriction_out: wp.array(dtype=wp.vec2),
    contact_solimp_out: wp.array(dtype=types.vec5),
    contact_dim_out: wp.array(dtype=int),
    contact_geom_out: wp.array(dtype=wp.vec2i),
    contact_efc_address_out: wp.array2d(dtype=int),
    contact_worldid_out: wp.array(dtype=int),
    contact_type_out: wp.array(dtype=int),
    contact_geomcollisionid_out: wp.array(dtype=int),
  ):
    conid = wp.tid()

    if conid >= nacon_in[0]:
      return

    worldid = contact_worldid_out[conid]
    if wp.static(reset is not None):
      if worldid >= 0:
        if not reset_in[worldid]:
          return

    contact_dist_out[conid] = 0.0
    contact_pos_out[conid] = wp.vec3(0.0)
    contact_frame_out[conid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    contact_includemargin_out[conid] = 0.0
    contact_friction_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_solref_out[conid] = wp.vec2(0.0, 0.0)
    contact_solreffriction_out[conid] = wp.vec2(0.0, 0.0)
    contact_solimp_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_dim_out[conid] = 0
    contact_geom_out[conid] = wp.vec2i(0, 0)
    for i in range(nefcaddress):
      contact_efc_address_out[conid, i] = 0
    contact_worldid_out[conid] = 0
    contact_type_out[conid] = 0
    contact_geomcollisionid_out[conid] = 0

  # explicit None test (not array truthiness); when reset is None the kernels never read the mask
  reset_input = reset if reset is not None else wp.ones(d.nworld, dtype=bool)

  wp.launch(reset_xfrc_applied, dim=(d.nworld, m.nbody, 6), inputs=[reset_input], outputs=[d.xfrc_applied])
  wp.launch(
    reset_qM,
    dim=(d.nworld, d.qM.shape[1], d.qM.shape[2]),
    inputs=[reset_input],
    outputs=[d.qM],
  )

  # set mocap_pos/quat = body_pos/quat for mocap bodies
  wp.launch(
    reset_mocap,
    dim=(d.nworld, m.nbody),
    inputs=[m.body_mocapid, m.body_pos, m.body_quat, reset_input],
    outputs=[d.mocap_pos, d.mocap_quat],
  )

  # clear contacts (must run before reset_nworld, which may zero nacon)
  wp.launch(
    reset_contact,
    dim=d.naconmax,
    inputs=[d.nacon, reset_input, d.contact.efc_address.shape[1]],
    outputs=[
      d.contact.dist,
      d.contact.pos,
      d.contact.frame,
      d.contact.includemargin,
      d.contact.friction,
      d.contact.solref,
      d.contact.solreffriction,
      d.contact.solimp,
      d.contact.dim,
      d.contact.geom,
      d.contact.efc_address,
      d.contact.worldid,
      d.contact.type,
      d.contact.geomcollisionid,
    ],
  )

  wp.launch(
    reset_nworld,
    dim=d.nworld,
    inputs=[m.nq, m.nv, m.nu, m.na, m.neq, m.nsensordata, m.qpos0, m.eq_active0, d.nworld, reset_input],
    outputs=[
      d.solver_niter,
      d.ne,
      d.nf,
      d.nl,
      d.nefc,
      d.time,
      d.energy,
      d.qpos,
      d.qvel,
      d.act,
      d.qacc_warmstart,
      d.ctrl,
      d.qfrc_applied,
      d.eq_active,
      d.qacc,
      d.act_dot,
      d.sensordata,
      d.nacon,
    ],
  )
# kernel_analyzer: off
@wp.kernel
def _init_subtreemass(
  body_mass_in: wp.array2d(dtype=float),
  body_subtreemass_out: wp.array2d(dtype=float),
):
  """Seeds every body's subtree mass with the body's own mass.

  Launched over (world, body). Either array may have fewer rows than there are
  worlds; the row is chosen as worldid modulo that array's row count.
  """
  worldid, bodyid = wp.tid()
  body_mass_id = worldid % body_mass_in.shape[0]
  body_subtreemass_id = worldid % body_subtreemass_out.shape[0]
  # NOTE(review): worlds that map to the same output row all write here; harmless only if
  # they read the same body_mass row — confirm for batched body_mass
  body_subtreemass_out[body_subtreemass_id, bodyid] = body_mass_in[body_mass_id, bodyid]
@wp.kernel
def _accumulate_subtreemass(
  body_parentid: wp.array(dtype=int),
  body_subtreemass_io: wp.array2d(dtype=float),
  body_tree_: wp.array(dtype=int),
):
  """Adds each body's subtree mass into its parent's, for one level of the body tree.

  Launched over (world, node in level). body_subtreemass_io may have fewer rows
  than there are worlds. Because the update is an atomic accumulation, only the
  first ``body_subtreemass_io.shape[0]`` worlds contribute: letting every
  aliasing world add into the same row would count the child mass once per world.
  """
  worldid, nodeid = wp.tid()

  # worlds beyond the stored rows would alias rows already handled by a lower worldid
  if worldid >= body_subtreemass_io.shape[0]:
    return

  body_subtreemass_id = worldid
  bodyid = body_tree_[nodeid]
  parentid = body_parentid[bodyid]

  # the world body (id 0) has no parent to accumulate into
  if bodyid != 0:
    wp.atomic_add(body_subtreemass_io, body_subtreemass_id, parentid, body_subtreemass_io[body_subtreemass_id, bodyid])
@wp.kernel
def _copy_qpos0_to_qpos(
  qpos0: wp.array2d(dtype=float),
  qpos_out: wp.array2d(dtype=float),
):
  """Sets each world's qpos to the model's qpos0, launched over (world, qpos index)."""
  worldid, i = wp.tid()
  # qpos0 may hold fewer rows than there are worlds; wrap the world index
  qpos0_id = worldid % qpos0.shape[0]
  qpos_out[worldid, i] = qpos0[qpos0_id, i]
@wp.kernel
def _copy_tendon_length0(
  ten_length_in: wp.array2d(dtype=float),
  tendon_length0_out: wp.array2d(dtype=float),
):
  """Stores the current per-world tendon lengths as the model's tendon_length0.

  Launched over (world, tendon). The output row is worldid modulo its row count,
  so several worlds may write the same row.
  """
  worldid, tenid = wp.tid()
  tendon_length0_id = worldid % tendon_length0_out.shape[0]
  tendon_length0_out[tendon_length0_id, tenid] = ten_length_in[worldid, tenid]
@wp.kernel
def _compute_meaninertia(
  nv: int,
  is_sparse: bool,
  dof_Madr_in: wp.array(dtype=int),
  qM_in: wp.array3d(dtype=float),
  meaninertia_out: wp.array(dtype=float),
):
  """Compute mean diagonal inertia from qM at qpos0."""
  worldid = wp.tid()

  # no dofs: nothing to average, fall back to 1.0
  if nv == 0:
    meaninertia_out[worldid % meaninertia_out.shape[0]] = 1.0  # Default from MuJoCo
    return

  # accumulator declared with an explicit constructor (warp style for locals mutated in a loop)
  total = float(0.0)
  for i in range(nv):
    if is_sparse:
      # Sparse: qM is flattened lower triangular, diagonal at dof_Madr[i]
      madr = dof_Madr_in[i]
      total += qM_in[worldid, 0, madr]
    else:
      # Dense: qM is 2D matrix, diagonal at [i,i]
      total += qM_in[worldid, i, i]

  # mean of the diagonal (trace / nv)
  meaninertia_out[worldid % meaninertia_out.shape[0]] = total / float(nv)
@wp.kernel
def _set_unit_vector(
  dofid_target: int,
  unit_vec_out: wp.array2d(dtype=float),
):
  """Writes the unit vector e_{dofid_target} into each world's row of unit_vec_out."""
  worldid = wp.tid()
  nv = unit_vec_out.shape[1]
  for i in range(nv):
    if i == dofid_target:
      unit_vec_out[worldid, i] = 1.0
    else:
      unit_vec_out[worldid, i] = 0.0
@wp.kernel
def _extract_dof_A_diag(
  dofid: int,
  result_vec_in: wp.array2d(dtype=float),
  dof_A_diag_out: wp.array2d(dtype=float),
):
  """Copies entry ``dofid`` of each world's result vector into dof_A_diag.

  NOTE(review): result_vec_in is presumably inv(M) applied to the unit vector of
  ``dofid``, making this the dofid-th diagonal entry of inv(M) — confirm at call site.
  """
  worldid = wp.tid()
  # output may hold fewer rows than there are worlds; wrap the world index
  dof_A_diag_id = worldid % dof_A_diag_out.shape[0]
  dof_A_diag_out[dof_A_diag_id, dofid] = result_vec_in[worldid, dofid]
@wp.kernel
def _finalize_dof_invweight0(
  dof_jntid: wp.array(dtype=int),
  jnt_type: wp.array(dtype=int),
  jnt_dofadr: wp.array(dtype=int),
  dof_A_diag_in: wp.array2d(dtype=float),
  dof_invweight0_out: wp.array2d(dtype=float),
):
  """Turns per-dof diagonal entries into dof_invweight0, averaging within joint groups.

  Launched over (world, dof). Free joints average their translational and rotational
  triples separately, ball joints average their three dofs, and other joints copy
  their single entry.
  """
  worldid, dofid = wp.tid()
  # input and output may each hold fewer rows than there are worlds
  dof_invweight0_id = worldid % dof_invweight0_out.shape[0]
  dof_A_diag_id = worldid % dof_A_diag_in.shape[0]

  jntid = dof_jntid[dofid]
  jtype = jnt_type[jntid]
  dofadr = jnt_dofadr[jntid]

  if jtype == int(types.JointType.FREE.value):
    # FREE joint: 6 DOFs, average first 3 (trans) and last 3 (rot) separately
    if dofid < dofadr + 3:
      avg = wp.static(1.0 / 3.0) * (
        dof_A_diag_in[dof_A_diag_id, dofadr + 0]
        + dof_A_diag_in[dof_A_diag_id, dofadr + 1]
        + dof_A_diag_in[dof_A_diag_id, dofadr + 2]
      )
    else:
      avg = wp.static(1.0 / 3.0) * (
        dof_A_diag_in[dof_A_diag_id, dofadr + 3]
        + dof_A_diag_in[dof_A_diag_id, dofadr + 4]
        + dof_A_diag_in[dof_A_diag_id, dofadr + 5]
      )
    dof_invweight0_out[dof_invweight0_id, dofid] = avg
  elif jtype == int(types.JointType.BALL.value):
    # BALL joint: 3 DOFs, average all
    avg = wp.static(1.0 / 3.0) * (
      dof_A_diag_in[dof_A_diag_id, dofadr + 0]
      + dof_A_diag_in[dof_A_diag_id, dofadr + 1]
      + dof_A_diag_in[dof_A_diag_id, dofadr + 2]
    )
    dof_invweight0_out[dof_invweight0_id, dofid] = avg
  else:
    # HINGE/SLIDE: 1 DOF, no averaging
    dof_invweight0_out[dof_invweight0_id, dofid] = dof_A_diag_in[dof_A_diag_id, dofid]
@wp.kernel
def _compute_body_jac_row(
  nv: int,
  bodyid_target: int,
  row_idx: int,
  body_parentid: wp.array(dtype=int),
  body_rootid: wp.array(dtype=int),
  body_dofadr: wp.array(dtype=int),
  body_dofnum: wp.array(dtype=int),
  dof_parentid: wp.array(dtype=int),
  subtree_com_in: wp.array2d(dtype=wp.vec3),
  xipos_in: wp.array2d(dtype=wp.vec3),
  cdof_in: wp.array2d(dtype=wp.spatial_vector),
  body_jac_row_out: wp.array2d(dtype=float),
):
  """Computes one row of the Jacobian of body ``bodyid_target`` at its xipos point.

  Rows 0-2 are translational (cdof linear part + cross(cdof angular part, offset)),
  rows 3-5 are rotational (cdof angular part). One thread per world.
  """
  worldid = wp.tid()

  # start from an all-zero row; only dofs on the ancestor chain get filled in
  for i in range(nv):
    body_jac_row_out[worldid, i] = 0.0

  # skip up past bodies that have no dofs of their own
  bodyid = bodyid_target
  while bodyid > 0 and body_dofnum[bodyid] == 0:
    bodyid = body_parentid[bodyid]

  # reached the world body: no dof moves this body, the row stays zero
  if bodyid == 0:
    return

  # Compute offset from point (xipos) to subtree_com of root body
  point = xipos_in[worldid, bodyid_target]
  offset = point - subtree_com_in[worldid, body_rootid[bodyid_target]]

  # Get last dof that affects this body
  dofid = body_dofadr[bodyid] + body_dofnum[bodyid] - 1

  # Backward pass over dof ancestor chain
  while dofid >= 0:
    cdof = cdof_in[worldid, dofid]
    cdof_ang = wp.spatial_top(cdof)
    cdof_lin = wp.spatial_bottom(cdof)

    if row_idx < 3:
      tmp = wp.cross(cdof_ang, offset)
      if row_idx == 0:
        body_jac_row_out[worldid, dofid] = cdof_lin[0] + tmp[0]
      elif row_idx == 1:
        body_jac_row_out[worldid, dofid] = cdof_lin[1] + tmp[1]
      else:
        body_jac_row_out[worldid, dofid] = cdof_lin[2] + tmp[2]
    else:
      if row_idx == 3:
        body_jac_row_out[worldid, dofid] = cdof_ang[0]
      elif row_idx == 4:
        body_jac_row_out[worldid, dofid] = cdof_ang[1]
      else:
        body_jac_row_out[worldid, dofid] = cdof_ang[2]

    dofid = dof_parentid[dofid]
@wp.kernel
def _compute_body_A_diag_entry(
  nv: int,
  bodyid_target: int,
  row_idx: int,
  body_jac_row_in: wp.array2d(dtype=float),
  result_vec_in: wp.array2d(dtype=float),
  body_A_diag_out: wp.array3d(dtype=float),
):
  """Stores the dot product of a Jacobian row with a result vector as one diagonal entry.

  One thread per world; the output row is worldid modulo its row count.
  """
  worldid = wp.tid()
  body_A_diag_id = worldid % body_A_diag_out.shape[0]

  # A[row,row] = J[row] · inv(M) · J[row]' = J[row] · result_vec
  dot_prod = float(0.0)
  for i in range(nv):
    dot_prod += body_jac_row_in[worldid, i] * result_vec_in[worldid, i]

  body_A_diag_out[body_A_diag_id, bodyid_target, row_idx] = dot_prod
@wp.kernel
def _finalize_body_invweight0(
  body_weldid: wp.array(dtype=int),
  body_A_diag_in: wp.array3d(dtype=float),
  body_invweight0_out: wp.array2d(dtype=wp.vec2),
):
  """Reduces the 6 per-body diagonal entries to (translational, rotational) invweight0.

  Launched over (world, body). Input and output rows are each selected as worldid
  modulo their row count.
  """
  worldid, bodyid = wp.tid()
  body_invweight0_id = worldid % body_invweight0_out.shape[0]
  body_A_diag_id = worldid % body_A_diag_in.shape[0]

  # World body and static bodies have zero invweight
  if bodyid == 0 or body_weldid[bodyid] == 0:
    body_invweight0_out[body_invweight0_id, bodyid] = wp.vec2(0.0, 0.0)
    return

  # Average diagonal: trans = (A[0,0]+A[1,1]+A[2,2])/3, rot = (A[3,3]+A[4,4]+A[5,5])/3
  inv_trans = wp.static(1.0 / 3.0) * (
    body_A_diag_in[body_A_diag_id, bodyid, 0]
    + body_A_diag_in[body_A_diag_id, bodyid, 1]
    + body_A_diag_in[body_A_diag_id, bodyid, 2]
  )
  inv_rot = wp.static(1.0 / 3.0) * (
    body_A_diag_in[body_A_diag_id, bodyid, 3]
    + body_A_diag_in[body_A_diag_id, bodyid, 4]
    + body_A_diag_in[body_A_diag_id, bodyid, 5]
  )

  # Prevent degenerate constraints: if one component is near zero, use the other as fallback
  if inv_trans < mujoco.mjMINVAL and inv_rot > mujoco.mjMINVAL:
    inv_trans = inv_rot  # use rotation as fallback for translation
  elif inv_rot < mujoco.mjMINVAL and inv_trans > mujoco.mjMINVAL:
    inv_rot = inv_trans  # use translation as fallback for rotation

  body_invweight0_out[body_invweight0_id, bodyid] = wp.vec2(inv_trans, inv_rot)
@wp.kernel
def _copy_tendon_jacobian(
  tenid_target: int,
  ten_J_in: wp.array3d(dtype=float),
  ten_J_vec_out: wp.array2d(dtype=float),
):
  """Copies row ``tenid_target`` of each world's tendon Jacobian into a per-world vector."""
  worldid = wp.tid()
  nv = ten_J_in.shape[2]
  for i in range(nv):
    ten_J_vec_out[worldid, i] = ten_J_in[worldid, tenid_target, i]
@wp.kernel
def _compute_tendon_dot_product(
  tenid_target: int,
  nv: int,
  ten_J_in: wp.array3d(dtype=float),
  result_vec_in: wp.array2d(dtype=float),
  tendon_invweight0_out: wp.array2d(dtype=float),
):
  """Stores J[tenid_target] · result_vec as tendon_invweight0 for that tendon.

  NOTE(review): result_vec_in is presumably inv(M) applied to the same Jacobian row,
  giving J inv(M) J' — confirm at the call site. One thread per world.
  """
  worldid = wp.tid()
  tendon_invweight0_id = worldid % tendon_invweight0_out.shape[0]

  dot_prod = float(0.0)
  for i in range(nv):
    dot_prod += ten_J_in[worldid, tenid_target, i] * result_vec_in[worldid, i]

  tendon_invweight0_out[tendon_invweight0_id, tenid_target] = dot_prod
@wp.kernel
def _compute_cam_pos0(
  cam_bodyid: wp.array(dtype=int),
  cam_targetbodyid: wp.array(dtype=int),
  cam_xpos_in: wp.array2d(dtype=wp.vec3),
  cam_xmat_in: wp.array2d(dtype=wp.mat33),
  xpos_in: wp.array2d(dtype=wp.vec3),
  subtree_com_in: wp.array2d(dtype=wp.vec3),
  cam_pos0_out: wp.array2d(dtype=wp.vec3),
  cam_poscom0_out: wp.array2d(dtype=wp.vec3),
  cam_mat0_out: wp.array2d(dtype=wp.mat33),
):
  """Records camera reference offsets and orientation at the current (qpos0) pose.

  Launched over (world, camera). cam_pos0 is the offset from the camera's body,
  cam_poscom0 the offset from the target body's subtree com (or the camera body's
  if there is no target), and cam_mat0 the camera orientation.
  """
  worldid, camid = wp.tid()
  # each output may be batched independently, so wrap worldid per array
  cam_pos0_id = worldid % cam_pos0_out.shape[0]
  cam_poscom0_id = worldid % cam_poscom0_out.shape[0]
  cam_mat0_id = worldid % cam_mat0_out.shape[0]

  bodyid = cam_bodyid[camid]
  targetid = cam_targetbodyid[camid]
  cam_xpos = cam_xpos_in[worldid, camid]

  cam_pos0_out[cam_pos0_id, camid] = cam_xpos - xpos_in[worldid, bodyid]
  if targetid >= 0:
    cam_poscom0_out[cam_poscom0_id, camid] = cam_xpos - subtree_com_in[worldid, targetid]
  else:
    cam_poscom0_out[cam_poscom0_id, camid] = cam_xpos - subtree_com_in[worldid, bodyid]
  cam_mat0_out[cam_mat0_id, camid] = cam_xmat_in[worldid, camid]
@wp.kernel
def _compute_light_pos0(
  light_bodyid: wp.array(dtype=int),
  light_targetbodyid: wp.array(dtype=int),
  light_xpos_in: wp.array2d(dtype=wp.vec3),
  light_xdir_in: wp.array2d(dtype=wp.vec3),
  xpos_in: wp.array2d(dtype=wp.vec3),
  subtree_com_in: wp.array2d(dtype=wp.vec3),
  light_pos0_out: wp.array2d(dtype=wp.vec3),
  light_poscom0_out: wp.array2d(dtype=wp.vec3),
  light_dir0_out: wp.array2d(dtype=wp.vec3),
):
  """Records light reference offsets and direction at the current (qpos0) pose.

  Launched over (world, light). light_pos0 is the offset from the light's body,
  light_poscom0 the offset from the target body's subtree com (or the light body's
  if there is no target), and light_dir0 the light direction.
  """
  worldid, lightid = wp.tid()
  # each output may be batched independently, so wrap worldid per array
  light_pos0_id = worldid % light_pos0_out.shape[0]
  light_poscom0_id = worldid % light_poscom0_out.shape[0]
  light_dir0_id = worldid % light_dir0_out.shape[0]

  bodyid = light_bodyid[lightid]
  targetid = light_targetbodyid[lightid]
  light_xpos = light_xpos_in[worldid, lightid]

  light_pos0_out[light_pos0_id, lightid] = light_xpos - xpos_in[worldid, bodyid]
  if targetid >= 0:
    light_poscom0_out[light_poscom0_id, lightid] = light_xpos - subtree_com_in[worldid, targetid]
  else:
    light_poscom0_out[light_poscom0_id, lightid] = light_xpos - subtree_com_in[worldid, bodyid]
  light_dir0_out[light_dir0_id, lightid] = light_xdir_in[worldid, lightid]
@wp.kernel
def _copy_actuator_moment(
  actid_target: int,
  actuator_moment_in: wp.array3d(dtype=float),
  act_moment_vec_out: wp.array2d(dtype=float),
):
  """Copies row ``actid_target`` of each world's dense actuator moment into a per-world vector."""
  worldid = wp.tid()
  nv = actuator_moment_in.shape[2]
  for i in range(nv):
    act_moment_vec_out[worldid, i] = actuator_moment_in[worldid, actid_target, i]
@wp.kernel
def _compute_actuator_acc0(
  actid_target: int,
  nv: int,
  result_vec_in: wp.array2d(dtype=float),
  actuator_acc0_out: wp.array(dtype=float),
):
  # One thread per world: writes the Euclidean norm of the first nv entries of this world's
  # result row into actuator_acc0_out[actid_target].
  # NOTE(review): every world writes the same element, so the stored value is from whichever
  # world writes last; this only matters if the per-world results differ.
  worldid = wp.tid()
  sum_sq = float(0.0)
  for dofid in range(nv):
    entry = result_vec_in[worldid, dofid]
    sum_sq += entry * entry
  actuator_acc0_out[actid_target] = wp.sqrt(sum_sq)
# kernel_analyzer: on
[docs]
def set_const_fixed(m: types.Model, d: types.Data):
  """Computes the model constants that do not depend on qpos0.

  Fields written:
    - body_subtreemass: mass of each body plus all of its descendants (from body_mass)
    - ngravcomp: number of bodies whose body_gravcomp is > 0 in any row of body_gravcomp

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  # seed every body's subtree mass, then fold children into parents level by level
  wp.launch(_init_subtreemass, dim=(d.nworld, m.nbody), inputs=[m.body_mass], outputs=[m.body_subtreemass])

  # m.body_tree is visited last-to-first (presumably deepest level first, so that a body's
  # subtree is complete before it is added to its parent — confirm against body_tree layout)
  for tree_level in reversed(m.body_tree):
    wp.launch(
      _accumulate_subtreemass,
      dim=(d.nworld, tree_level.size),
      inputs=[m.body_parentid, m.body_subtreemass, tree_level],
    )

  # TODO(team): refactor for graph capture compatibility
  has_gravcomp = m.body_gravcomp.numpy() > 0.0
  m.ngravcomp = int(np.count_nonzero(has_gravcomp.any(axis=0)))
[docs]
def set_const_0(m: types.Model, d: types.Data):
  """Compute quantities that depend on qpos0.

  Temporarily sets d.qpos to m.qpos0, runs the smooth pipeline through transmission and
  derives the reference quantities listed below. d.qpos is restored afterwards, also when
  one of the steps raises. Other fields of d written by the smooth.* calls are not
  restored and reflect the qpos0 configuration.

  Computes:
    - stat.meaninertia: from the qM diagonal at qpos0
    - tendon_length0: tendon resting lengths
    - dof_invweight0: inverse inertia for DOFs
    - body_invweight0: inverse spatial inertia for bodies
    - tendon_invweight0: inverse weight for tendons
    - cam_pos0, cam_poscom0, cam_mat0: camera references
    - light_pos0, light_poscom0, light_dir0: light references
    - actuator_acc0: acceleration from unit actuator force

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  qpos_saved = wp.clone(d.qpos)
  try:
    wp.launch(_copy_qpos0_to_qpos, dim=(d.nworld, m.nq), inputs=[m.qpos0], outputs=[d.qpos])
    smooth.kinematics(m, d)
    smooth.com_pos(m, d)
    smooth.camlight(m, d)
    smooth.flex(m, d)
    smooth.tendon(m, d)
    smooth.crb(m, d)
    smooth.tendon_armature(m, d)
    smooth.factor_m(m, d)
    smooth.transmission(m, d)

    # meaninertia from the qM diagonal at qpos0
    wp.launch(
      _compute_meaninertia,
      dim=d.nworld,
      inputs=[m.nv, m.is_sparse, m.dof_Madr, d.qM],
      outputs=[m.stat.meaninertia],
    )
    wp.launch(_copy_tendon_length0, dim=(d.nworld, m.ntendon), inputs=[d.ten_length], outputs=[m.tendon_length0])

    _set_const_dof_invweight0(m, d)
    _set_const_body_invweight0(m, d)
    _set_const_tendon_invweight0(m, d)

    wp.launch(
      _compute_cam_pos0,
      dim=(d.nworld, m.ncam),
      inputs=[m.cam_bodyid, m.cam_targetbodyid, d.cam_xpos, d.cam_xmat, d.xpos, d.subtree_com],
      outputs=[m.cam_pos0, m.cam_poscom0, m.cam_mat0],
    )
    wp.launch(
      _compute_light_pos0,
      dim=(d.nworld, m.nlight),
      inputs=[m.light_bodyid, m.light_targetbodyid, d.light_xpos, d.light_xdir, d.xpos, d.subtree_com],
      outputs=[m.light_pos0, m.light_poscom0, m.light_dir0],
    )

    _set_const_actuator_acc0(m, d)
  finally:
    # never leave the caller's configuration overwritten with qpos0
    wp.copy(d.qpos, qpos_saved)


def _set_const_dof_invweight0(m: types.Model, d: types.Data):
  """Fills m.dof_invweight0 from the diagonal of inv(M) at the current configuration.

  Averaging per joint (done by _finalize_dof_invweight0):
    FREE: 6 DOFs, trans gets mean(A[0:3]), rot gets mean(A[3:6])
    BALL: 3 DOFs, all get mean(A[0:3])
    HINGE/SLIDE: 1 DOF, gets A[0,0]
  """
  if m.nv == 0:
    return
  unit_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  dof_A_diag = wp.zeros((d.nworld, m.nv), dtype=float)
  # TODO(team): more efficient approach instead of looping over nv?
  for dofid in range(m.nv):
    wp.launch(_set_unit_vector, dim=d.nworld, inputs=[dofid], outputs=[unit_vec])
    smooth.solve_m(m, d, result_vec, unit_vec)
    wp.launch(_extract_dof_A_diag, dim=d.nworld, inputs=[dofid, result_vec], outputs=[dof_A_diag])
  wp.launch(
    _finalize_dof_invweight0,
    dim=(d.nworld, m.nv),
    inputs=[m.dof_jntid, m.jnt_type, m.jnt_dofadr, dof_A_diag],
    outputs=[m.dof_invweight0],
  )


def _set_const_body_invweight0(m: types.Model, d: types.Data):
  """Fills m.body_invweight0 from the diagonal of J * inv(M) * J' per body (zero if nv == 0).

  J is the 6 x nv body Jacobian (3 rows translation, 3 rows rotation); the world body (id 0)
  is skipped.
  """
  if m.nv == 0:
    m.body_invweight0.zero_()
    return
  body_jac_row = wp.zeros((d.nworld, m.nv), dtype=float)
  body_result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  body_A_diag = wp.zeros((d.nworld, m.nbody, 6), dtype=float)
  # TODO(team): more efficient approach instead of nested iterations?
  for bodyid in range(1, m.nbody):
    for row_idx in range(6):
      wp.launch(
        _compute_body_jac_row,
        dim=d.nworld,
        inputs=[
          m.nv,
          bodyid,
          row_idx,
          m.body_parentid,
          m.body_rootid,
          m.body_dofadr,
          m.body_dofnum,
          m.dof_parentid,
          d.subtree_com,
          d.xipos,
          d.cdof,
        ],
        outputs=[body_jac_row],
      )
      smooth.solve_m(m, d, body_result_vec, body_jac_row)
      wp.launch(
        _compute_body_A_diag_entry,
        dim=d.nworld,
        inputs=[m.nv, bodyid, row_idx, body_jac_row, body_result_vec],
        outputs=[body_A_diag],
      )
  wp.launch(
    _finalize_body_invweight0,
    dim=(d.nworld, m.nbody),
    inputs=[m.body_weldid, body_A_diag],
    outputs=[m.body_invweight0],
  )


def _set_const_tendon_invweight0(m: types.Model, d: types.Data):
  """Fills m.tendon_invweight0[t] = J_t * inv(M) * J_t' for every tendon."""
  if m.ntendon == 0:
    return
  ten_J_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  ten_result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  for tenid in range(m.ntendon):
    wp.launch(_copy_tendon_jacobian, dim=d.nworld, inputs=[tenid, d.ten_J], outputs=[ten_J_vec])
    smooth.solve_m(m, d, ten_result_vec, ten_J_vec)
    wp.launch(
      _compute_tendon_dot_product,
      dim=d.nworld,
      inputs=[tenid, m.nv, d.ten_J, ten_result_vec],
      outputs=[m.tendon_invweight0],
    )


def _set_const_actuator_acc0(m: types.Model, d: types.Data):
  """Fills m.actuator_acc0[i] = ||inv(M) * actuator_moment[i]|| (acceleration from unit force)."""
  if m.nu == 0 or m.nv == 0:
    return
  act_moment_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  act_result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  for actid in range(m.nu):
    wp.launch(_copy_actuator_moment, dim=d.nworld, inputs=[actid, d.actuator_moment], outputs=[act_moment_vec])
    smooth.solve_m(m, d, act_result_vec, act_moment_vec)
    wp.launch(_compute_actuator_acc0, dim=d.nworld, inputs=[actid, m.nv, act_result_vec], outputs=[m.actuator_acc0])
[docs]
def set_const(m: types.Model, d: types.Data):
  """Recomputes derived constant model fields (fixed and qpos0-dependent).

  Call this after modifying model parameters at runtime: it propagates changes from some
  model fields into the fields derived from them, making modifications safe that would
  otherwise leave the model inconsistent.

  Model fields that can be modified safely with set_const:

    Field                            | Notes
    ---------------------------------|----------------------------------------------
    qpos0, qpos_spring               |
    body_mass, body_inertia,         | Mass and inertia are usually scaled together
    body_ipos, body_iquat            | since inertia is sum(m * r^2).
    body_pos, body_quat              | Unsafe for static bodies (invalidates BVH).
    body_gravcomp                    | If changing from 0 to >0 bodies, required.
    dof_armature                     |
    eq_data                          | For connect/weld, offsets computed if not set.
    hfield_size                      |
    tendon_stiffness, tendon_damping | Only if changing from/to zero.
    actuator_gainprm, actuator_biasprm | For position actuators with dampratio.

  For selective updates, call the sub-function matching what changed:

    Modified Field  | Call
    ----------------|------------------
    body_mass       | set_const
    body_gravcomp   | set_const_fixed
    body_inertia    | set_const_0
    qpos0           | set_const_0

  Computes:
    - Fixed quantities (set_const_fixed):
      - body_subtreemass: mass of body and all descendants
      - ngravcomp: count of bodies with gravity compensation
    - qpos0-dependent quantities (set_const_0):
      - tendon_length0: tendon resting lengths
      - dof_invweight0: inverse inertia for DOFs
      - body_invweight0: inverse spatial inertia for bodies
      - tendon_invweight0: inverse weight for tendons
      - cam_pos0, cam_poscom0, cam_mat0: camera references
      - light_pos0, light_poscom0, light_dir0: light references
      - actuator_acc0: acceleration from unit actuator force

  Not computed: dof_M0, actuator_length0 (not in mjwarp).

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  # the configuration-independent stage runs before the qpos0 stage
  for stage in (set_const_fixed, set_const_0):
    stage(m, d)
def override_model(model: types.Model | mujoco.MjModel, overrides: dict[str, Any] | Sequence[str]):
  """Overrides model parameters.

  Overrides are of the format:
    opt.iterations = 1
    opt.ls_parallel = True
    opt.cone = pyramidal
    opt.disableflags = contact | spring

  Overrides may be a dict mapping dotted field paths to values, or a sequence of
  "field = value" strings. String values are converted according to the current type of
  the target field:
    - enum names (case-insensitive; combined with "|" for flag fields),
    - booleans ("true"/"false", case-insensitive, surrounding whitespace ignored),
    - float lists for array fields, whitespace- or comma-separated and optionally wrapped
      in brackets, e.g. "[0 0 -9.81]" or "0, 0, -9.81". For wp.array fields the floats
      build a single element of the field's dtype, placed on the field's device.
  Any other value is passed to the field's current type constructor. Fields that exist
  only on mjw.Model are skipped for an MjModel, and vice versa.

  Args:
    model: The model to modify in place (mjw.Model or mujoco.MjModel).
    overrides: The overrides, as a dict or as a sequence of "key = value" strings.

  Raises:
    ValueError: If an override string has no "=", names an unknown field, or has an
      unrecognized enum or boolean value.
  """
  enum_fields = {
    "opt.broadphase": types.BroadphaseType,
    "opt.broadphase_filter": types.BroadphaseFilter,
    "opt.cone": types.ConeType,
    "opt.disableflags": types.DisableBit,
    "opt.enableflags": types.EnableBit,
    "opt.integrator": types.IntegratorType,
    "opt.solver": types.SolverType,
  }
  # MuJoCo pybind11 enums don't support iteration, so we provide explicit mappings
  mj_enum_fields = {
    "opt.jacobian": {
      "DENSE": mujoco.mjtJacobian.mjJAC_DENSE,
      "SPARSE": mujoco.mjtJacobian.mjJAC_SPARSE,
      "AUTO": mujoco.mjtJacobian.mjJAC_AUTO,
    },
  }
  mjw_only_fields = {"opt.broadphase", "opt.broadphase_filter", "opt.ls_parallel", "opt.graph_conditional"}
  mj_only_fields = {"opt.jacobian"}

  def _parse_floats(text: str) -> list[float]:
    # strip outer whitespace before the brackets so " [1 2 3] " parses; commas act as separators
    return [float(p) for p in text.strip().strip("[]").replace(",", " ").split()]

  if not isinstance(overrides, dict):
    overrides_dict = {}
    for override in overrides:
      if "=" not in override:
        raise ValueError(f"Invalid override format: {override}")
      k, v = override.split("=", 1)
      overrides_dict[k.strip()] = v.strip()
    overrides = overrides_dict

  for key, val in overrides.items():
    # skip overrides on MjModel for properties that are only on mjw.Model
    if key in mjw_only_fields and isinstance(model, mujoco.MjModel):
      continue
    if key in mj_only_fields and isinstance(model, types.Model):
      continue

    obj, attrs = model, key.split(".")
    for i, attr in enumerate(attrs):
      if not hasattr(obj, attr):
        raise ValueError(f"Unrecognized model field: {key}")
      if i < len(attrs) - 1:
        obj = getattr(obj, attr)
        continue

      typ = type(getattr(obj, attr))
      if key in mj_enum_fields and isinstance(val, str):
        enum_member = val.strip().upper()
        if enum_member not in mj_enum_fields[key]:
          raise ValueError(f"Unrecognized enum value for {key}: {enum_member}")
        val = mj_enum_fields[key][enum_member]
      elif key in enum_fields and isinstance(val, str):
        # special case: enum value, possibly several flags joined with "|"
        enum_members = val.split("|")
        val = 0
        for enum_member in enum_members:
          enum_member = enum_member.strip().upper()
          if enum_member not in enum_fields[key].__members__:
            raise ValueError(f"Unrecognized enum value for {enum_fields[key].__name__}: {enum_member}")
          val |= int(enum_fields[key][enum_member])
      elif typ is bool and isinstance(val, str):
        # special case: "true", "TRUE", "false", "FALSE" etc.
        flag = val.strip().upper()
        if flag not in ("TRUE", "FALSE"):
          raise ValueError(f"Unrecognized value for field: {key}")
        val = flag == "TRUE"
      elif typ is wp.array and isinstance(val, str):
        arr = getattr(obj, attr)
        floats = _parse_floats(val)
        # keep the replacement on the same device as the array it replaces
        val = wp.array([arr.dtype(*floats)], dtype=arr.dtype, device=arr.device)
      elif typ is np.ndarray and isinstance(val, str):
        arr = getattr(obj, attr)
        val = np.array(_parse_floats(val), dtype=arr.dtype)
      else:
        val = typ(val)
      setattr(obj, attr, val)
def find_keys(model: mujoco.MjModel, keyname_prefix: str) -> list[int]:
  """Finds keyframes whose name starts with keyname_prefix.

  Unnamed keyframes are treated as having an empty name, so they match only an empty
  prefix (previously they raised AttributeError).

  Args:
    model: The MuJoCo model whose keyframes are searched.
    keyname_prefix: Prefix the keyframe name must start with.

  Returns:
    Ids of the matching keyframes, in increasing order.
  """
  # mj_id2name returns None (NULL in the C API) for a keyframe without a name
  return [
    keyid
    for keyid in range(model.nkey)
    if (mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_KEY, keyid) or "").startswith(keyname_prefix)
  ]
def make_trajectory(model: mujoco.MjModel, keys: list[int]) -> np.ndarray:
  """Make a ctrl trajectory with linear interpolation between keyframes.

  Sample i of the trajectory corresponds to time i * model.opt.timestep. Between two
  consecutive keyframes, controls are linearly interpolated in time from the previous
  keyframe's (time, ctrl) to the next one's. Each keyframe's ctrl is emitted at the first
  sample whose time is >= that keyframe's time.

  Args:
    model: The MuJoCo model providing key_ctrl, key_time and opt.timestep.
    keys: Keyframe ids, in the order they should be visited.

  Returns:
    Array of controls, one row per sample (empty if keys is empty).

  Raises:
    ValueError: If the first keyframe's time is not 0.0 or keyframe times are not
      strictly increasing.
  """
  timestep = model.opt.timestep
  ctrls = []
  prev_ctrl_key = np.zeros(model.nu, dtype=np.float64)
  prev_time = 0.0
  for key in keys:
    ctrl_key, ctrl_time = model.key_ctrl[key], model.key_time[key]
    if not ctrls and ctrl_time != 0.0:
      raise ValueError("first keyframe must have time 0.0")
    elif ctrls and ctrl_time <= prev_time:
      raise ValueError("keyframes must be in time order")
    # derive sample times from the sample index to avoid accumulating float error
    time = len(ctrls) * timestep
    while time < ctrl_time:
      frac = (time - prev_time) / (ctrl_time - prev_time)
      ctrls.append(prev_ctrl_key * (1 - frac) + ctrl_key * frac)
      time = len(ctrls) * timestep
    ctrls.append(ctrl_key)
    prev_ctrl_key = ctrl_key
    # the next segment starts at this keyframe's own time, not at the following sample's time
    prev_time = ctrl_time
  return np.array(ctrls)
@wp.kernel
def _build_rays(
  # In:
  offset: int,
  img_w: int,
  img_h: int,
  projection: int,
  fovy: float,
  sensorsize: wp.vec2,
  intrinsic: wp.vec4,
  znear: float,
  # Out:
  ray_out: wp.array(dtype=wp.vec3),
):
  # One thread per pixel (xid, yid) of a single camera image: stores the camera-frame ray from
  # render_util.compute_ray at `offset` plus the row-major pixel index.
  xid, yid = wp.tid()
  pixel = yid * img_w + xid
  ray = render_util.compute_ray(projection, fovy, sensorsize, intrinsic, img_w, img_h, xid, yid, znear)
  ray_out[offset + pixel] = ray
[docs]
def create_render_context(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  cam_res: list[tuple[int, int]] | tuple[int, int] | None = None,
  render_rgb: list[bool] | bool | None = None,
  render_depth: list[bool] | bool | None = None,
  use_textures: bool = True,
  use_shadows: bool = False,
  enabled_geom_groups: Sequence[int] = (0, 1, 2),
  cam_active: list[bool] | None = None,
  flex_render_smooth: bool = True,
  use_precomputed_rays: bool = True,
) -> types.RenderContext:
  """Creates a render context on device.

  Args:
    mjm: The model containing kinematic and dynamic information on host.
    nworld: The number of worlds.
    cam_res: The width and height to render each camera image. A single tuple applies to
      every active camera. If None, uses the MuJoCo model values.
    render_rgb: Whether to render RGB images, per active camera or as a single bool for all
      of them. If None, uses the MuJoCo model values.
    render_depth: Whether to render depth images, per active camera or as a single bool for
      all of them. If None, uses the MuJoCo model values.
    use_textures: Whether to use textures (disabled with a warning if warp lacks
      Texture2D support).
    use_shadows: Whether to use shadows.
    enabled_geom_groups: The geom groups to render.
    cam_active: List of booleans indicating which cameras to include in rendering.
      If None, all cameras are included.
    flex_render_smooth: Whether to render flex meshes smoothly.
    use_precomputed_rays: Use precomputed rays instead of computing during rendering.
      When using domain randomization for camera intrinsics, set to False.

  Returns:
    The render context containing rendering fields and output arrays on device.

  Raises:
    AssertionError: If cam_active, cam_res, render_rgb or render_depth have the wrong length.
  """
  mjd = mujoco.MjData(mjm)
  mujoco.mj_forward(mjm, mjd)

  # TODO(team): remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml
  if use_textures and not hasattr(wp, "Texture2D"):
    warnings.warn("Textures require warp >= 1.12. Disabling textures.", stacklevel=2)
    use_textures = False

  # Mesh BVHs (only for meshes referenced by an enabled geom)
  nmesh = mjm.nmesh
  geom_enabled_mask = np.isin(mjm.geom_group, list(enabled_geom_groups))
  mesh_geom_mask = geom_enabled_mask & (mjm.geom_type == types.GeomType.MESH) & (mjm.geom_dataid >= 0)
  used_mesh_id = set(mjm.geom_dataid[mesh_geom_mask].astype(int))
  geom_enabled_idx = np.nonzero(geom_enabled_mask)[0]
  mesh_registry = {}
  mesh_bvh_id = [wp.uint64(0) for _ in range(nmesh)]
  mesh_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nmesh)]
  for mid in used_mesh_id:
    mesh, half = bvh.build_mesh_bvh(mjm, mid)
    mesh_registry[mesh.id] = mesh
    mesh_bvh_id[mid] = mesh.id
    mesh_bounds_size[mid] = half
  mesh_bvh_id_arr = wp.array(mesh_bvh_id, dtype=wp.uint64)
  mesh_bounds_size_arr = wp.array(mesh_bounds_size, dtype=wp.vec3)

  # HField BVHs (only for hfields referenced by an enabled geom)
  nhfield = mjm.nhfield
  hfield_geom_mask = geom_enabled_mask & (mjm.geom_type == types.GeomType.HFIELD) & (mjm.geom_dataid >= 0)
  used_hfield_id = set(mjm.geom_dataid[hfield_geom_mask].astype(int))
  hfield_registry = {}
  hfield_bvh_id = [wp.uint64(0) for _ in range(nhfield)]
  hfield_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nhfield)]
  for hid in used_hfield_id:
    hmesh, hhalf = bvh.build_hfield_bvh(mjm, hid)
    hfield_registry[hmesh.id] = hmesh
    hfield_bvh_id[hid] = hmesh.id
    hfield_bounds_size[hid] = hhalf
  hfield_bvh_id_arr = wp.array(hfield_bvh_id, dtype=wp.uint64)
  hfield_bounds_size_arr = wp.array(hfield_bounds_size, dtype=wp.vec3)

  # Flex BVHs
  flex_bvh_id = wp.uint64(0)
  flex_group_root = wp.zeros(nworld, dtype=int)
  flex_mesh = None
  flex_face_point = None
  flex_elemdataadr = None
  flex_shell = None
  flex_shelldataadr = None
  flex_faceadr = None
  flex_nface = 0
  flex_radius = None
  flex_workadr = None
  flex_worknum = None
  flex_nwork = 0

  if mjm.nflex > 0:
    (
      fmesh,
      face_point,
      flex_group_roots,
      flex_shell_data,
      flex_faceadr_data,
      flex_nface,
    ) = bvh.build_flex_bvh(mjm, mjd, nworld)

    flex_mesh = fmesh
    flex_bvh_id = fmesh.id
    flex_face_point = face_point
    flex_group_root = flex_group_roots
    flex_elemdataadr = wp.array(mjm.flex_elemdataadr, dtype=int)
    flex_shell = flex_shell_data
    flex_shelldataadr = wp.array(mjm.flex_shelldataadr, dtype=int)
    flex_faceadr = wp.array(flex_faceadr_data, dtype=int)
    flex_radius = wp.array(mjm.flex_radius, dtype=float)

    # precompute work item layout for unified refit kernel
    nflex = mjm.nflex
    workadr = np.zeros(nflex, dtype=np.int32)
    worknum = np.zeros(nflex, dtype=np.int32)
    cumsum = 0
    for f in range(nflex):
      workadr[f] = cumsum
      if mjm.flex_dim[f] == 2:
        worknum[f] = mjm.flex_elemnum[f] + mjm.flex_shellnum[f]
      else:
        worknum[f] = mjm.flex_shellnum[f]
      cumsum += worknum[f]
    flex_workadr = wp.array(workadr, dtype=int)
    flex_worknum = wp.array(worknum, dtype=int)
    flex_nwork = int(cumsum)

  textures_registry = []
  # TODO: remove after mjwarp depends on warp-lang >= 1.12 in pyproject.toml
  if hasattr(wp, "Texture2D"):
    for i in range(mjm.ntex):
      textures_registry.append(render_util.create_warp_texture(mjm, i))
    textures = wp.array(textures_registry, dtype=wp.Texture2D)
  else:
    # Dummy array when texture support isn't available (warp < 1.12)
    textures = wp.zeros(1, dtype=int)

  # Filter active cameras
  if cam_active is not None:
    assert len(cam_active) == mjm.ncam, f"cam_active must have length {mjm.ncam} (got {len(cam_active)})"
    active_cam_indices = np.nonzero(cam_active)[0]
  else:
    active_cam_indices = list(range(mjm.ncam))

  ncam = len(active_cam_indices)

  if cam_res is not None:
    if isinstance(cam_res, tuple):
      cam_res = [cam_res] * ncam
    assert len(cam_res) == ncam, (
      f"Camera resolutions must be provided for all active cameras (got {len(cam_res)}, expected {ncam})"
    )
    active_cam_res = cam_res
  else:
    active_cam_res = mjm.cam_resolution[active_cam_indices]

  cam_res_arr = wp.array(active_cam_res, dtype=wp.vec2i)

  # a scalar False must expand to per-camera flags too (it used to reach len() and crash)
  render_rgb = _expand_cam_flags(render_rgb, ncam, mjm, active_cam_indices, "mjCAMOUT_RGB")
  render_depth = _expand_cam_flags(render_depth, ncam, mjm, active_cam_indices, "mjCAMOUT_DEPTH")

  assert len(render_rgb) == ncam and len(render_depth) == ncam, (
    f"Render RGB and depth must be provided for all active cameras (got {len(render_rgb)}, {len(render_depth)}, expected {ncam})"
  )

  # per-camera pixel counts as Python ints so the running offsets cannot overflow int32
  cam_res_np = cam_res_arr.numpy()
  cam_npix = [int(res[0]) * int(res[1]) for res in cam_res_np]

  # offsets of each camera's image in the packed rgb / depth buffers (-1 = not rendered)
  rgb_adr = np.full(ncam, -1, dtype=int)
  depth_adr = np.full(ncam, -1, dtype=int)
  ri = 0
  di = 0
  for idx, npix in enumerate(cam_npix):
    if render_rgb[idx]:
      rgb_adr[idx] = ri
      ri += npix
    if render_depth[idx]:
      depth_adr[idx] = di
      di += npix
  total = sum(cam_npix)

  znear = mjm.vis.map.znear * mjm.stat.extent

  ray = wp.zeros(total, dtype=wp.vec3)

  # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
  cam_projection = np.zeros(mjm.ncam, dtype=int)
  if BLEEDING_EDGE_MUJOCO:
    cam_projection = mjm.cam_projection

  # rays for every active camera, packed back to back in camera order
  offset = 0
  for idx, cam_id in enumerate(active_cam_indices):
    img_w = int(cam_res_np[idx][0])
    img_h = int(cam_res_np[idx][1])
    wp.launch(
      kernel=_build_rays,
      dim=(img_w, img_h),
      inputs=[
        offset,
        img_w,
        img_h,
        int(cam_projection[cam_id]),
        float(mjm.cam_fovy[cam_id]),
        wp.vec2(mjm.cam_sensorsize[cam_id]),
        wp.vec4(mjm.cam_intrinsic[cam_id]),
        znear,
      ],
      outputs=[ray],
    )
    offset += img_w * img_h

  bvh_ngeom = len(geom_enabled_idx)

  rc = types.RenderContext(
    nrender=ncam,
    cam_res=cam_res_arr,
    cam_id_map=wp.array(active_cam_indices, dtype=int),
    use_textures=use_textures,
    use_shadows=use_shadows,
    background_color=render_util.pack_rgba_to_uint32(0.1 * 255.0, 0.1 * 255.0, 0.2 * 255.0, 1.0 * 255.0),
    use_precomputed_rays=use_precomputed_rays,
    bvh_ngeom=bvh_ngeom,
    enabled_geom_ids=wp.array(geom_enabled_idx, dtype=int),
    mesh_registry=mesh_registry,
    mesh_bvh_id=mesh_bvh_id_arr,
    mesh_bounds_size=mesh_bounds_size_arr,
    mesh_texcoord=wp.array(mjm.mesh_texcoord, dtype=wp.vec2),
    mesh_texcoord_offsets=wp.array(mjm.mesh_texcoordadr, dtype=int),
    mesh_facetexcoord=wp.array(mjm.mesh_facetexcoord, dtype=wp.vec3i),
    textures=textures,
    textures_registry=textures_registry,
    hfield_registry=hfield_registry,
    hfield_bvh_id=hfield_bvh_id_arr,
    hfield_bounds_size=hfield_bounds_size_arr,
    flex_mesh=flex_mesh,
    flex_rgba=wp.array(mjm.flex_rgba, dtype=wp.vec4),
    flex_bvh_id=flex_bvh_id,
    flex_face_point=flex_face_point,
    flex_faceadr=flex_faceadr,
    flex_nface=flex_nface,
    flex_nwork=flex_nwork,
    flex_group_root=flex_group_root,
    flex_elemdataadr=flex_elemdataadr,
    flex_shell=flex_shell,
    flex_shelldataadr=flex_shelldataadr,
    flex_radius=flex_radius,
    flex_workadr=flex_workadr,
    flex_worknum=flex_worknum,
    flex_render_smooth=flex_render_smooth,
    bvh=None,
    bvh_id=None,
    lower=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3),
    upper=wp.zeros(nworld * bvh_ngeom, dtype=wp.vec3),
    group=wp.zeros(nworld * bvh_ngeom, dtype=int),
    group_root=wp.zeros(nworld, dtype=int),
    ray=ray,
    rgb_data=wp.zeros((nworld, ri), dtype=wp.uint32),
    rgb_adr=wp.array(rgb_adr, dtype=int),
    depth_data=wp.zeros((nworld, di), dtype=wp.float32),
    depth_adr=wp.array(depth_adr, dtype=int),
    render_rgb=wp.array(render_rgb, dtype=bool),
    render_depth=wp.array(render_depth, dtype=bool),
    znear=znear,
    total_rays=int(total),
  )

  bvh.build_scene_bvh(mjm, mjd, rc, nworld)

  return rc


def _expand_cam_flags(
  flags: list[bool] | bool | None,
  ncam: int,
  mjm: mujoco.MjModel,
  active_cam_indices: Sequence[int],
  outbit: str,
) -> list[bool]:
  """Expands a render_rgb / render_depth argument into one flag per active camera.

  A scalar bool (True or False) applies to every active camera. None reads the `outbit`
  bit of mjm.cam_output per camera, or enables every camera on MuJoCo versions without
  cam_output. Any other value (a per-camera sequence) is returned unchanged.
  """
  if isinstance(flags, (bool, np.bool_)):
    return [bool(flags)] * ncam
  if flags is None:
    # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
    if BLEEDING_EDGE_MUJOCO:
      bit = int(getattr(mujoco.mjtCamOutBit, outbit))
      return [bool(int(mjm.cam_output[i]) & bit) for i in active_cam_indices]
    return [True] * ncam
  return flags