# Source code for mujoco.mjx._src.sensor

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sensor functions."""

import jax
from jax import numpy as jp
import mujoco
# pylint: disable=g-importing-member
from mujoco.mjx._src import math
from mujoco.mjx._src import ray
from mujoco.mjx._src import smooth
from mujoco.mjx._src import support
from mujoco.mjx._src.types import Data
from mujoco.mjx._src.types import DataJAX
from mujoco.mjx._src.types import DisableBit
from mujoco.mjx._src.types import Model
from mujoco.mjx._src.types import ModelJAX
from mujoco.mjx._src.types import ObjType
from mujoco.mjx._src.types import SensorType
from mujoco.mjx._src.types import TrnType
# pylint: enable=g-importing-member
import numpy as np


def _apply_cutoff(
    sensor: jax.Array, cutoff: jax.Array, data_type: int
) -> jax.Array:
  """Limit each sensor reading by its own cutoff.

  The leading axis of `sensor` and `cutoff` is mapped over together, so entry
  `i` of `sensor` is limited by entry `i` of `cutoff`.

  Args:
    sensor: sensor readings, batched along the leading axis
    cutoff: one cutoff per reading; values that are not positive leave the
      reading untouched
    data_type: sensor data type, compared against mujoco.mjtDataType values

  Returns:
    readings limited to [-cutoff, cutoff] for real-valued data, limited from
    above by cutoff for positive data, and passed through for any other type
  """
  real = mujoco.mjtDataType.mjDATATYPE_REAL
  positive = mujoco.mjtDataType.mjDATATYPE_POSITIVE

  def _limit(value, bound):
    if data_type == real:
      limited = jp.clip(value, -bound, bound)
    elif data_type == positive:
      limited = jp.minimum(value, bound)
    else:
      # other data types are never cut off
      return value
    # a cutoff only takes effect when it is positive
    return jp.where(bound > 0, limited, value)

  return jax.vmap(_limit)(sensor, cutoff)


def sensor_pos(m: Model, d: Data) -> Data:
  """Compute position-dependent sensors values.

  Sensor types that have at least one sensor in the position stage are
  evaluated in batches, one batch per type (frame sensors are further split by
  object/reference type pair, rangefinders by site body). Results are written
  into `d.sensordata` at the addresses given by `m.sensor_adr`.

  Args:
    m: model; `m._impl` must be a ModelJAX
    d: data; `d._impl` must be a DataJAX

  Returns:
    `d` with `sensordata` updated. `d` is returned unchanged when sensors are
    disabled or when no supported position-stage sensor is present.

  Raises:
    ValueError: if `m` or `d` is not backed by the JAX implementation, or if a
      frame quaternion sensor references an unknown object type.
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('sensor_pos requires JAX backend implementation.')

  if m.opt.disableflags & DisableBit.SENSOR:
    return d

  # position and orientation by object type
  objtype_data = {
      ObjType.UNKNOWN: (
          np.zeros((1, 3)),
          np.expand_dims(np.eye(3), axis=0),
      ),  # world
      ObjType.BODY: (d.xipos, d.ximat),
      ObjType.XBODY: (d.xpos, d.xmat),
      ObjType.GEOM: (d.geom_xpos, d.geom_xmat),
      ObjType.SITE: (d.site_xpos, d.site_xmat),
      ObjType.CAMERA: (d.cam_xpos, d.cam_xmat),
  }

  # frame axis indexing
  frame_axis = {
      SensorType.FRAMEXAXIS: 0,
      SensorType.FRAMEYAXIS: 1,
      SensorType.FRAMEZAXIS: 2,
  }

  stage_pos = m.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS
  sensors, adrs = [], []

  for sensor_type in set(m.sensor_type[stage_pos]):
    idx = m.sensor_type == sensor_type
    objid = m.sensor_objid[idx]
    objtype = m.sensor_objtype[idx]
    refid = m.sensor_refid[idx]
    reftype = m.sensor_reftype[idx]
    adr = m.sensor_adr[idx]
    cutoff = m.sensor_cutoff[idx]
    data_type = m.sensor_datatype[idx]

    if sensor_type == SensorType.MAGNETOMETER:
      # magnetic field expressed in the site frame
      sensor = jax.vmap(lambda xmat: xmat.T @ m.opt.magnetic)(
          d.site_xmat[objid]
      )
      adr = (adr[:, None] + np.arange(3)[None]).reshape(-1)
    elif sensor_type == SensorType.CAMPROJECTION:

      @jax.vmap
      def _cam_project(
          target_xpos, xpos, xmat, res, fovy, intrinsic, sensorsize, focal_flag
      ):
        translation = jp.eye(4).at[0:3, 3].set(-xpos)
        rotation = jp.eye(4).at[:3, :3].set(xmat.T)

        # focal transformation matrix (3 x 4)
        f = 0.5 / jp.tan(fovy * jp.pi / 360.0) * res[1]
        fx, fy = jp.where(
            focal_flag,
            intrinsic[:2] / (sensorsize[:2] + mujoco.mjMINVAL) * res[:2],
            f,
        )  # add mjMINVAL to denominator to prevent divide by zero warning

        focal = jp.array([[-fx, 0, 0, 0], [0, fy, 0, 0], [0, 0, 1.0, 0]])

        # image matrix (3 x 3)
        image = jp.eye(3).at[:2, 2].set(res[0:2] / 2.0)

        # projection matrix (3 x 4): product of all 4 matrices
        proj = image @ focal @ rotation @ translation

        # projection matrix multiplies homogenous [x, y, z, 1] vectors
        pos_hom = jp.append(target_xpos, 1.0)

        # project world coordinates into pixel space, see:
        # https://en.wikipedia.org/wiki/3D_projection#Mathematical_formula
        pixel_coord_hom = proj @ pos_hom

        # avoid dividing by tiny numbers: when |denom| < mjMINVAL, push it away
        # from zero to +/- mjMINVAL while keeping its sign. Clipping to
        # [-mjMINVAL, mjMINVAL] would be a no-op here because denom already
        # lies inside that interval.
        denom = pixel_coord_hom[2]
        denom = jp.where(
            jp.abs(denom) < mujoco.mjMINVAL,
            jp.where(
                denom < 0,
                jp.minimum(denom, -mujoco.mjMINVAL),
                jp.maximum(denom, mujoco.mjMINVAL),
            ),
            denom,
        )

        # compute projection
        sensor = pixel_coord_hom / denom

        return sensor[:2]

      sensorsize = m.cam_sensorsize[refid]
      intrinsic = m.cam_intrinsic[refid]
      fovy = m.cam_fovy[refid]
      res = m.cam_resolution[refid]
      # use the intrinsic-based focal length only if both sensor sizes are set
      focal_flag = np.logical_and(sensorsize[:, 0] != 0, sensorsize[:, 1] != 0)

      target_xpos = d.site_xpos[objid]
      xpos = d.cam_xpos[refid]
      xmat = d.cam_xmat[refid]

      sensor = _cam_project(
          target_xpos, xpos, xmat, res, fovy, intrinsic, sensorsize, focal_flag
      )
      adr = (adr[:, None] + np.arange(2)[None]).reshape(-1)
    elif sensor_type == SensorType.RANGEFINDER:
      site_bodyid = m.site_bodyid[objid]
      # batch rangefinder sites by body: the body id is passed unbatched as the
      # last argument of ray.ray
      # NOTE(review): presumably that argument is the body excluded from the
      # ray test - confirm against ray.ray
      for sid in set(site_bodyid):
        idxs = sid == site_bodyid
        objids = objid[idxs]
        site_xpos = d.site_xpos[objids]
        # third column of each site rotation matrix is the ray direction
        site_mat = d.site_xmat[objids].reshape((-1, 9))[:, np.array([2, 5, 8])]
        cutoffs = cutoff[idxs]
        sensor, _ = jax.vmap(
            ray.ray, in_axes=(None, None, 0, 0, None, None, None)
        )(m, d, site_xpos, site_mat, (), True, sid)
        sensors.append(_apply_cutoff(sensor, cutoffs, data_type[0]))
        adrs.append(adr[idxs])

      continue  # avoid adding to sensors/adrs list a second time
    elif sensor_type == SensorType.JOINTPOS:
      sensor = d.qpos[m.jnt_qposadr[objid]]
    elif sensor_type == SensorType.TENDONPOS:
      sensor = d.ten_length[objid]
    elif sensor_type == SensorType.ACTUATORPOS:
      sensor = d.actuator_length[objid]
    elif sensor_type == SensorType.BALLQUAT:
      jnt_qposadr = m.jnt_qposadr[objid, None] + np.arange(4)[None]
      quat = d.qpos[jnt_qposadr]
      sensor = jax.vmap(math.normalize)(quat)
      adr = (adr[:, None] + np.arange(4)[None]).reshape(-1)
    elif sensor_type == SensorType.FRAMEPOS:

      def _framepos(xpos, xpos_ref, xmat_ref, refid):
        # refid == -1: no reference frame, report the global position
        return jp.where(refid == -1, xpos, xmat_ref.T @ (xpos - xpos_ref))

      # evaluate for valid object and reference object type pairs
      for ot, rt in set(zip(objtype, reftype)):
        idxt = (objtype == ot) & (reftype == rt)
        refidt = refid[idxt]
        xpos, _ = objtype_data[ot]
        xpos_ref, xmat_ref = objtype_data[rt]
        xpos = xpos[objid[idxt]]
        xpos_ref = xpos_ref[refidt]
        xmat_ref = xmat_ref[refidt]
        cutofft = cutoff[idxt]
        sensor = jax.vmap(_framepos)(xpos, xpos_ref, xmat_ref, refidt)
        adrt = adr[idxt, None] + np.arange(3)[None]
        sensors.append(_apply_cutoff(sensor, cutofft, data_type[0]).reshape(-1))
        adrs.append(adrt.reshape(-1))

      continue  # avoid adding to sensors/adrs list a second time
    elif sensor_type in frame_axis:

      def _frameaxis(xmat, xmat_ref, refid):
        axis = xmat[:, frame_axis[sensor_type]]
        # refid == -1: no reference frame, report the global axis
        return jp.where(refid == -1, axis, xmat_ref.T @ axis)

      # evaluate for valid object and reference object type pairs
      for ot, rt in set(zip(objtype, reftype)):
        idxt = (objtype == ot) & (reftype == rt)
        refidt = refid[idxt]
        _, xmat = objtype_data[ot]
        _, xmat_ref = objtype_data[rt]
        xmat = xmat[objid[idxt]]
        xmat_ref = xmat_ref[refidt]
        cutofft = cutoff[idxt]
        sensor = jax.vmap(_frameaxis)(xmat, xmat_ref, refidt)
        adrt = adr[idxt, None] + np.arange(3)[None]
        sensors.append(_apply_cutoff(sensor, cutofft, data_type[0]).reshape(-1))
        adrs.append(adrt.reshape(-1))

      continue  # avoid adding to sensors/adrs list a second time
    elif sensor_type == SensorType.FRAMEQUAT:

      def _quat(otype, oid):
        # global orientation quaternion of objects `oid` of type `otype`
        if otype == ObjType.XBODY:
          return d.xquat[oid]
        elif otype == ObjType.BODY:
          return jax.vmap(math.quat_mul)(d.xquat[oid], m.body_iquat[oid])
        elif otype == ObjType.GEOM:
          return jax.vmap(math.quat_mul)(
              d.xquat[m.geom_bodyid[oid]], m.geom_quat[oid]
          )
        elif otype == ObjType.SITE:
          return jax.vmap(math.quat_mul)(
              d.xquat[m.site_bodyid[oid]], m.site_quat[oid]
          )
        elif otype == ObjType.CAMERA:
          return jax.vmap(math.quat_mul)(
              d.xquat[m.cam_bodyid[oid]], m.cam_quat[oid]
          )
        elif otype == ObjType.UNKNOWN:
          # world frame: identity quaternion per sensor
          return jp.tile(jp.array([1.0, 0.0, 0.0, 0.0]), (oid.size, 1))
        else:
          raise ValueError(f'Unknown object type: {otype}')

      # evaluate for valid object and reference object type pairs
      for ot, rt in set(zip(objtype, reftype)):
        idxt = (objtype == ot) & (reftype == rt)
        objidt = objid[idxt]
        refidt = refid[idxt]
        quat = _quat(ot, objidt)
        refquat = _quat(rt, refidt)
        cutofft = cutoff[idxt]
        sensor = jax.vmap(
            lambda q, r, rid: jp.where(
                rid == -1, q, math.quat_mul(math.quat_inv(r), q)
            )
        )(quat, refquat, refidt)
        adrt = adr[idxt, None] + np.arange(4)[None]
        sensors.append(_apply_cutoff(sensor, cutofft, data_type[0]).reshape(-1))
        adrs.append(adrt.reshape(-1))

      continue  # avoid adding to sensors/adrs list a second time
    elif sensor_type == SensorType.SUBTREECOM:
      sensor = d.subtree_com[objid]
      adr = (adr[:, None] + np.arange(3)[None]).reshape(-1)
    elif sensor_type == SensorType.CLOCK:
      # one copy of the simulation time per clock sensor
      sensor = jp.repeat(d.time, int(idx.sum()))
    else:
      # TODO(taylorhowell): raise error after adding sensor check to io.py
      continue  # unsupported sensor type

    sensors.append(_apply_cutoff(sensor, cutoff, data_type[0]).reshape(-1))
    adrs.append(adr)

  if not adrs:
    return d

  sensordata = d.sensordata.at[np.concatenate(adrs)].set(
      jp.concatenate(sensors)
  )

  return d.replace(sensordata=sensordata)
def sensor_vel(m: Model, d: Data) -> Data:
  """Compute velocity-dependent sensors values.

  Sensor types with at least one sensor in the velocity stage are evaluated in
  batches and scattered into `d.sensordata` at the addresses given by
  `m.sensor_adr`.

  Args:
    m: model; `m._impl` must be a ModelJAX
    d: data; `d._impl` must be a DataJAX

  Returns:
    `d` with `sensordata` updated (and with subtree velocities recomputed when
    a subtree velocity/momentum sensor is present). `d` is returned as-is when
    sensors are disabled.

  Raises:
    ValueError: if `m` or `d` is not backed by the JAX implementation.
  """
  jax_backend = isinstance(m._impl, ModelJAX) and isinstance(d._impl, DataJAX)
  if not jax_backend:
    raise ValueError('sensor_vel requires JAX backend implementation.')

  if m.opt.disableflags & DisableBit.SENSOR:
    return d

  # per object type: (positions, orientations, owning body ids)
  body_ids = np.arange(m.nbody)
  objtype_data = {
      # world
      ObjType.UNKNOWN: (
          np.zeros((1, 3)),
          np.expand_dims(np.eye(3), axis=0),
          np.arange(1),
      ),
      ObjType.BODY: (d.xipos, d.ximat, body_ids),
      ObjType.XBODY: (d.xpos, d.xmat, body_ids),
      ObjType.GEOM: (d.geom_xpos, d.geom_xmat, m.geom_bodyid),
      ObjType.SITE: (d.site_xpos, d.site_xmat, m.site_bodyid),
      ObjType.CAMERA: (d.cam_xpos, d.cam_xmat, m.cam_bodyid),
  }

  in_vel_stage = m.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL
  sensor_types = set(m.sensor_type[in_vel_stage])

  # subtree sensors read quantities produced by smooth.subtree_vel
  if sensor_types & {SensorType.SUBTREELINVEL, SensorType.SUBTREEANGMOM}:
    d = smooth.subtree_vel(m, d)

  def _span(start, width):
    # flat sensordata addresses for sensors that are `width` elements wide
    return (start[:, None] + np.arange(width)[None]).reshape(-1)

  def _cvel_offset(otype, oid):
    # velocity of the owning body and the object's offset from the subtree com
    # of that body's root
    pos, _, owner = objtype_data[otype]
    owner = owner[oid]
    return d.cvel[owner], pos[oid] - d.subtree_com[m.body_rootid[owner]]

  # batched rotation of vectors into local frames: mat.T @ vec
  to_local = jax.vmap(lambda mat, vec: mat.T @ vec)

  values, addresses = [], []

  for sensor_type in sensor_types:
    idx = m.sensor_type == sensor_type
    objid = m.sensor_objid[idx]
    adr = m.sensor_adr[idx]
    cutoff = m.sensor_cutoff[idx]
    data_type = m.sensor_datatype[idx]

    if sensor_type == SensorType.VELOCIMETER:
      bodyid = m.site_bodyid[objid]
      site_offset = d.site_xpos[objid] - d.subtree_com[m.body_rootid[bodyid]]
      sensor = jax.vmap(
          lambda vec, dif, rot: rot.T @ (vec[3:] - jp.cross(dif, vec[:3]))
      )(d.cvel[bodyid], site_offset, d.site_xmat[objid])
      adr = _span(adr, 3)
    elif sensor_type == SensorType.GYRO:
      bodyid = m.site_bodyid[objid]
      sensor = to_local(d.site_xmat[objid], d.cvel[bodyid, :3])
      adr = _span(adr, 3)
    elif sensor_type == SensorType.JOINTVEL:
      sensor = d.qvel[m.jnt_dofadr[objid]]
    elif sensor_type == SensorType.TENDONVEL:
      sensor = d._impl.ten_velocity[objid]
    elif sensor_type == SensorType.ACTUATORVEL:
      sensor = d._impl.actuator_velocity[objid]
    elif sensor_type == SensorType.BALLANGVEL:
      # three consecutive dofs starting at each joint's dof address
      dof_adr = m.jnt_dofadr[objid, None] + np.arange(3)[None]
      sensor = d.qvel[dof_adr]
      adr = _span(adr, 3)
    elif sensor_type in {SensorType.FRAMELINVEL, SensorType.FRAMEANGVEL}:
      objtype = m.sensor_objtype[idx]
      reftype = m.sensor_reftype[idx]
      refid = m.sensor_refid[idx]

      # one batch per (object type, reference type) pair that occurs
      for ot, rt in set(zip(objtype, reftype)):
        pair = (objtype == ot) & (reftype == rt)
        objidt = objid[pair]
        refidt = refid[pair]
        # sensors with a reference frame (refid > -1) report relative values
        has_ref = (refidt > -1)[:, None]

        xmatref = objtype_data[rt][1][refidt]

        cvel, offset = _cvel_offset(ot, objidt)
        cvelref, offsetref = _cvel_offset(rt, refidt)
        cangvel = cvel[:, :3]
        cangvelref = cvelref[:, :3]

        if sensor_type == SensorType.FRAMELINVEL:
          xpos = objtype_data[ot][0][objidt]
          xposref = objtype_data[rt][0][refidt]
          xlinvel = cvel[:, 3:] - jp.cross(offset, cangvel)
          xlinvelref = cvelref[:, 3:] - jp.cross(offsetref, cangvelref)
          rvec = xpos - xposref
          rel_vel = xlinvel - xlinvelref + jp.cross(rvec, cangvelref)
          sensor = jp.where(has_ref, to_local(xmatref, rel_vel), xlinvel)
        elif sensor_type == SensorType.FRAMEANGVEL:
          rel_vel = cangvel - cangvelref
          sensor = jp.where(has_ref, to_local(xmatref, rel_vel), cangvel)
        else:
          raise ValueError(f'Unknown sensor type: {sensor_type}')

        limited = _apply_cutoff(sensor, cutoff[pair], data_type[0])
        values.append(limited.reshape(-1))
        addresses.append(_span(adr[pair], 3))

      # frame sensors were appended per pair above
      continue
    elif sensor_type == SensorType.SUBTREELINVEL:
      sensor = d._impl.subtree_linvel[objid]
      adr = _span(adr, 3)
    elif sensor_type == SensorType.SUBTREEANGMOM:
      sensor = d._impl.subtree_angmom[objid]
      adr = _span(adr, 3)
    else:
      # TODO(taylorhowell): raise error after adding sensor check to io.py
      continue  # unsupported sensor type

    values.append(_apply_cutoff(sensor, cutoff, data_type[0]).reshape(-1))
    addresses.append(adr)

  if not addresses:
    return d

  flat_adr = np.concatenate(addresses)
  flat_val = jp.concatenate(values)

  return d.replace(sensordata=d.sensordata.at[flat_adr].set(flat_val))
def sensor_acc(m: Model, d: Data) -> Data:
  """Compute acceleration/force-dependent sensors values.

  Sensor types with at least one sensor in the acceleration stage are evaluated
  in batches and scattered into `d.sensordata` at the addresses given by
  `m.sensor_adr`. Cutoffs are applied to every handled sensor type except
  contact sensors.

  Args:
    m: model; `m._impl` must be a ModelJAX
    d: data; `d._impl` must be a DataJAX

  Returns:
    `d` with `sensordata` updated (and with the outputs of
    smooth.rne_postconstraint when an accelerometer, force, torque or frame
    acceleration sensor is present). `d` is returned as-is when sensors are
    disabled.

  Raises:
    ValueError: if `m` or `d` is not backed by the JAX implementation.
    NotImplementedError: for contact sensors whose object/reference types are
      neither unspecified nor geoms.
  """
  if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX):
    raise ValueError('sensor_acc requires JAX backend implementation.')

  if m.opt.disableflags & DisableBit.SENSOR:
    return d

  # position and bodyid by object type
  objtype_data = {
      ObjType.UNKNOWN: (np.zeros((1, 3)), np.arange(1)),
      ObjType.BODY: (d.xipos, np.arange(m.nbody)),
      ObjType.XBODY: (d.xpos, np.arange(m.nbody)),
      ObjType.GEOM: (d.geom_xpos, m.geom_bodyid),
      ObjType.SITE: (d.site_xpos, m.site_bodyid),
      ObjType.CAMERA: (d.cam_xpos, m.cam_bodyid),
  }

  stage_acc = m.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC
  sensor_types = set(m.sensor_type[stage_acc])

  # these sensors read d._impl.cacc / d._impl.cfrc_int
  if sensor_types & {
      SensorType.ACCELEROMETER,
      SensorType.FORCE,
      SensorType.TORQUE,
      SensorType.FRAMELINACC,
      SensorType.FRAMEANGACC,
  }:
    d = smooth.rne_postconstraint(m, d)

  # contact sensor parameters: intprm[:, 0] is the data spec bit field,
  # intprm[:, 1] the reduction criterion (2: maxforce)
  contact_intprm = m.sensor_intprm[m.sensor_type == SensorType.CONTACT]
  contact_maxforce = (contact_intprm[:, 1] == 2).any()
  contact_dataforce = (contact_intprm[:, 0] & (1 << 1)).any()
  contact_datatorque = (contact_intprm[:, 0] & (1 << 2)).any()

  has_touch = (m.sensor_type[stage_acc] == SensorType.TOUCH).any()
  has_contact = (m.sensor_type[stage_acc] == SensorType.CONTACT).any()
  contact_needs_force = (
      contact_maxforce or contact_dataforce or contact_datatorque
  )

  if has_touch or (has_contact and contact_needs_force):
    # compute contact forces, one batch per contact dimension, then restore
    # the original contact ordering
    contact_force = []
    condim_ids = []
    for dim in set(d._impl.contact.dim):
      force, condim_id = support.contact_force_dim(m, d, dim)
      contact_force.append(force)
      condim_ids.append(condim_id)
    contact_force = jp.concatenate(contact_force)[
        np.argsort(np.concatenate(condim_ids))
    ]

  sensors, adrs = [], []

  for sensor_type in sensor_types:
    idx = m.sensor_type == sensor_type
    objid = m.sensor_objid[idx]
    adr = m.sensor_adr[idx]
    cutoff = m.sensor_cutoff[idx]
    data_type = m.sensor_datatype[idx]

    if sensor_type == SensorType.TOUCH:
      # get bodies of contact geoms
      conbody = jp.array(m.geom_bodyid)[d._impl.contact.geom]

      # get site information
      site_bodyid = m.site_bodyid[objid]
      site_size = m.site_size[objid]
      site_xpos = d.site_xpos[objid]
      site_xmat = d.site_xmat[objid]
      site_type = m.site_type[objid]

      # (site, contact) masks: contact involves the site's body
      conbody0 = site_bodyid[:, None] == conbody[:, 0]
      conbody1 = site_bodyid[:, None] == conbody[:, 1]
      contacts = (d._impl.contact.efc_address >= 0)[None] & (
          conbody0 | conbody1
      )

      # compute conray, flip if second body
      conray = jax.vmap(
          lambda frame, force: math.normalize(frame[0] * force[0])
      )(d._impl.contact.frame, contact_force)
      conray = jp.where(conbody1[..., None], -conray, conray)

      # compute distance, mapping over sites and contacts
      def _distance(site_size, site_xpos, site_xmat, site_type, pos, conray):
        def dist(size, xpos, xmat, conray):
          # contact point and ray expressed in the site frame
          pnt = (pos - xpos) @ xmat
          vec = conray @ xmat
          ray_geom_ = lambda pnt, vec: ray.ray_geom(size, pnt, vec, site_type)
          return jax.vmap(ray_geom_)(pnt, vec)

        return jax.vmap(dist)(site_size, site_xpos, site_xmat, conray)

      # site type is passed unbatched, so evaluate one batch per site type
      dist = []
      dist_id = []
      for st in set(site_type):
        (dist_id_site,) = np.nonzero(st == site_type)
        dist_site = _distance(
            site_size[dist_id_site],
            site_xpos[dist_id_site],
            site_xmat[dist_id_site],
            st,
            d._impl.contact.pos,
            conray[dist_id_site],
        )
        # infinite distances are replaced by 0 so they fail `dist > 0` below
        dist.append(jp.where(jp.isinf(dist_site), 0, dist_site))
        dist_id.append(dist_id_site)
      # restore the original site ordering
      dist = jp.vstack(dist)[np.argsort(np.concatenate(dist_id))]

      # accumulate normal forces for each site
      sensor = jp.dot((dist > 0) & contacts, contact_force[:, 0])
    elif sensor_type == SensorType.CONTACT:
      # maximum number of contacts
      ncon = d._impl.ncon

      # active contacts
      dist = d._impl.contact.dist
      pos = dist - d._impl.contact.includemargin
      is_contact = pos < 0

      # reduction criteria
      if contact_maxforce:
        # compute force magnitude for each contact
        force_mag = jax.vmap(
            lambda forcetorque: jp.dot(forcetorque[:3], forcetorque[:3])
        )(contact_force)

      def _reduce(reduction, mask):
        if reduction == 1:  # mindist
          return jp.argsort(pos * mask, descending=False)
        if reduction == 2:  # maxforce
          return jp.argsort(force_mag * mask, descending=True)
        return jp.arange(mask.size)

      # number of data elements per slot
      def nslotdata(dataspec):
        size = 0
        # found, force, torque, dist, pos, normal, tangent
        # TODO(taylorhowell): get sizes from mjCONDATA_SIZE
        for i, size_i in enumerate([1, 3, 3, 1, 3, 3, 3]):
          if dataspec & (1 << i):
            size += size_i
        return size

      dataspecs, reduces, _ = m.sensor_intprm[idx].T
      dims = m.sensor_dim[idx]
      objtypes = m.sensor_objtype[idx]
      refid = m.sensor_refid[idx]
      reftypes = m.sensor_reftype[idx]

      # one batch per unique sensor configuration
      for dataspec, reduce, objtype, reftype, dim in set(
          zip(dataspecs, reduces, objtypes, reftypes, dims)
      ):
        idx_ds = (
            (dataspec == dataspecs)
            & (reduce == reduces)
            & (objtype == objtypes)
            & (reftype == reftypes)
            & (dim == dims)
        )

        # TODO(taylorhowell): site filter

        size = nslotdata(dataspec)
        # number of slots per sensor, limited by the number of contacts
        num = np.minimum(int(dim / size), ncon)
        nsensor = idx_ds.sum()

        if objtype == ObjType.UNKNOWN and reftype == ObjType.UNKNOWN:
          # all contacts match; builtin `bool` is used because the `np.bool`
          # alias is missing in NumPy 1.24-1.26
          match = np.ones(ncon, dtype=bool)

          # matched and reduced contact ids
          sort = _reduce(reduce, match)
          cid = sort[:num]

          # number of contacts per sensor
          nfound = is_contact.sum()

          # if duplicate sensor
          cid = jp.tile(cid, (nsensor,))
          nfound = jp.tile(nfound, (nsensor,))

          flip = jp.ones((cid.size, 3))
        elif objtype == ObjType.GEOM or reftype == ObjType.GEOM:
          sensorid1 = objid[idx_ds]
          sensorid2 = refid[idx_ds]

          geomid0 = d._impl.contact.geom[:, 0]
          geomid1 = d._impl.contact.geom[:, 1]

          # match sensor ids and contact geom ids
          geom0id1 = geomid0 == sensorid1[:, None]
          geom0id2 = geomid0 == sensorid2[:, None]
          geom1id1 = geomid1 == sensorid1[:, None]
          geom1id2 = geomid1 == sensorid2[:, None]

          if objtype == ObjType.GEOM and reftype == ObjType.UNKNOWN:  # geom1
            mask12 = geom0id1
            mask21 = geom1id1
          elif objtype == ObjType.UNKNOWN and reftype == ObjType.GEOM:  # geom2
            mask12 = geom0id2
            mask21 = geom1id2
          else:  # geom1, geom2
            mask12 = geom0id1 & geom1id2
            mask21 = geom0id2 & geom1id1

          match = mask12 | mask21

          # matched and reduced contact ids
          cid = jax.vmap(lambda x: _reduce(reduce, x))(match)[:, :num]
          cid = cid.reshape(-1)

          # flip direction for force, torque, normal, tangent
          if reftype == ObjType.UNKNOWN:  # geom1
            is_flip = (geomid1[cid] == np.repeat(sensorid1, num))[:, None]
          elif objtype == ObjType.UNKNOWN:  # geom2
            is_flip = (geomid0[cid] == np.repeat(sensorid2, num))[:, None]
          else:  # geom1, geom2
            is_flip = np.repeat(sensorid1 > sensorid2, num)[:, None]

          flip = jp.where(
              is_flip,
              jp.array([[1, 1, -1]]),
              jp.array([[1, 1, 1]]),
          )

          # number of contacts per sensor
          nfound = (match * is_contact[None, :]).sum(axis=1)

          # TODO(taylorhowell): matching criteria: body, subtree
        else:
          raise NotImplementedError(
              f'Unsupported contact sensor semantics: {objtype} {reftype}.'
          )

        # assemble the requested fields of every slot, in data spec bit order
        slot = []
        if dataspec & (1 << 0):  # found
          slot.append(jp.repeat(nfound, num)[:, None])
        if dataspec & (1 << 1):  # force
          slot.append(flip * contact_force[cid, :3])
        if dataspec & (1 << 2):  # torque
          slot.append(flip * contact_force[cid, 3:])
        if dataspec & (1 << 3):  # dist
          slot.append(dist[cid, None])
        if dataspec & (1 << 4):  # pos
          slot.append(d._impl.contact.pos[cid])
        if dataspec & (1 << 5):  # normal
          slot.append(flip[:, 2, None] * d._impl.contact.frame[cid, 0])
        if dataspec & (1 << 6):  # tangent
          slot.append(flip[:, 2, None] * d._impl.contact.frame[cid, 1])

        # zero out slots beyond the number of contacts found for each sensor
        found = jp.tile(jp.arange(num), nsensor) < jp.repeat(nfound, num)
        sensors.append((found[:, None] * jp.hstack(slot)).reshape(-1))
        adrs.append(
            (adr[idx_ds][:, None] + np.arange(num * size)[None]).reshape(-1)
        )

      continue  # avoid adding to sensors/adrs list a second time
    elif sensor_type == SensorType.ACCELEROMETER:

      @jax.vmap
      def _accelerometer(cvel, cacc, diff, rot):
        ang = rot.T @ cvel[:3]
        lin = rot.T @ (cvel[3:] - jp.cross(diff, cvel[:3]))
        acc = rot.T @ (cacc[3:] - jp.cross(diff, cacc[:3]))
        correction = jp.cross(ang, lin)
        return acc + correction

      bodyid = m.site_bodyid[objid]
      rot = d.site_xmat[objid]
      cvel = d.cvel[bodyid]
      cacc = d._impl.cacc[bodyid]
      dif = d.site_xpos[objid] - d.subtree_com[m.body_rootid[bodyid]]

      sensor = _accelerometer(cvel, cacc, dif, rot)
      adr = (adr[:, None] + np.arange(3)[None]).reshape(-1)
    elif sensor_type == SensorType.FORCE:
      bodyid = m.site_bodyid[objid]
      cfrc_int = d._impl.cfrc_int[bodyid]
      site_xmat = d.site_xmat[objid]
      sensor = jax.vmap(lambda mat, vec: mat.T @ vec)(
          site_xmat, cfrc_int[:, 3:]
      )
      adr = (adr[:, None] + np.arange(3)[None]).reshape(-1)
    elif sensor_type == SensorType.TORQUE:
      bodyid = m.site_bodyid[objid]
      rootid = m.body_rootid[bodyid]
      cfrc_int = d._impl.cfrc_int[bodyid]
      site_xmat = d.site_xmat[objid]
      dif = d.site_xpos[objid] - d.subtree_com[rootid]
      sensor = jax.vmap(
          lambda vec, dif, rot: rot.T @ (vec[:3] - jp.cross(dif, vec[3:]))
      )(cfrc_int, dif, site_xmat)
      adr = (adr[:, None] + np.arange(3)[None]).reshape(-1)
    elif sensor_type == SensorType.ACTUATORFRC:
      sensor = d.actuator_force[objid]
    elif sensor_type == SensorType.JOINTACTFRC:
      sensor = d.qfrc_actuator[m.jnt_dofadr[objid]]
    elif sensor_type == SensorType.TENDONACTFRC:
      # per sensor, mask of actuators whose transmission is that tendon
      force_mask = [
          (m.actuator_trntype == TrnType.TENDON)
          & (m.actuator_trnid[:, 0] == tendon_id)
          for tendon_id in objid
      ]
      force_ids = np.concatenate([np.nonzero(mask)[0] for mask in force_mask])
      # (sensor, selected actuator) selection matrix: sums forces per tendon
      force_mat = np.array(force_mask)[:, force_ids]
      sensor = force_mat @ d.actuator_force[force_ids]
    elif sensor_type in (SensorType.FRAMELINACC, SensorType.FRAMEANGACC):
      objtype = m.sensor_objtype[idx]

      # one batch per object type
      for ot in set(objtype):
        idxt = objtype == ot
        objidt = objid[idxt]
        pos, bodyid = objtype_data[ot]
        pos = pos[objidt]
        bodyid = bodyid[objidt]
        cacc = d._impl.cacc[bodyid]

        if sensor_type == SensorType.FRAMELINACC:

          @jax.vmap
          def _framelinacc(cvel, cacc, offset):
            ang = cvel[:3]
            lin = cvel[3:] - jp.cross(offset, cvel[:3])
            acc = cacc[3:] - jp.cross(offset, cacc[:3])
            correction = jp.cross(ang, lin)
            return acc + correction

          cvel = d.cvel[bodyid]
          offset = pos - d.subtree_com[m.body_rootid[bodyid]]
          sensor = _framelinacc(cvel, cacc, offset)
        elif sensor_type == SensorType.FRAMEANGACC:
          sensor = cacc[:, :3]
        else:
          raise ValueError(f'Unknown sensor type: {sensor_type}')

        adrt = adr[idxt, None] + np.arange(3)[None]
        # keep one row per sensor so each sensor's own cutoff is applied, as
        # for every other non-contact sensor type
        sensors.append(
            _apply_cutoff(sensor, cutoff[idxt], data_type[0]).reshape(-1)
        )
        adrs.append(adrt.reshape(-1))

      continue  # avoid adding to sensors/adrs list a second time
    else:
      # TODO(taylorhowell): raise error after adding sensor check to io.py
      continue  # unsupported sensor type

    sensors.append(_apply_cutoff(sensor, cutoff, data_type[0]).reshape(-1))
    adrs.append(adr)

  if not adrs:
    return d

  sensordata = d.sensordata.at[np.concatenate(adrs)].set(
      jp.concatenate(sensors)
  )

  return d.replace(sensordata=sensordata)